Skip to main content

rustauth_core/context/
builder.rs

1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4#[cfg(feature = "oauth")]
5use rustauth_oauth::oauth2::SocialOAuthProvider;
6
7use time::Duration;
8
9use crate::background::tokio::TokioBackgroundTaskRunner;
10use crate::cookies::get_cookies;
11use crate::crypto::password::{hash_password, verify_password};
12use crate::crypto::{build_secret_config, parse_secrets_env};
13use crate::db::RateLimitStorage as DbRateLimitStorage;
14use crate::db::{auth_schema, AuthSchemaOptions, DbAdapter, DbField, HookedAdapter};
15use crate::env::is_production_posture;
16use crate::env::logger::create_logger;
17use crate::error::RustAuthError;
18use crate::options::hooks::{plugin_after_hooks, plugin_before_hooks};
19use crate::options::RateLimitStore;
20use crate::options::{
21    plugin_database_hooks_from_init, BackgroundTaskRunner, ModelSchemaOptions,
22    RateLimitStorageOption, RustAuthOptions, SessionAdditionalField, UserAdditionalField,
23};
24use crate::plugin::AuthPlugin;
25use crate::rate_limit::{GovernorMemoryRateLimitStore, LegacyRateLimitStorageAdapter};
26
27use super::origins::resolve_trusted_origins;
28use super::plugins::initialize_plugins;
29use super::secrets::{resolve_legacy_secret, validate_secret, DEFAULT_SECRET};
30use super::{
31    noop_telemetry_publisher, AuthContext, AuthEnvironment, PasswordContext, PasswordPolicy,
32    RateLimitContext, SecretMaterial, SessionConfig,
33};
34
35pub fn create_auth_context(options: RustAuthOptions) -> Result<AuthContext, RustAuthError> {
36    create_auth_context_with_environment_and_adapter(options, AuthEnvironment::from_process(), None)
37}
38
39pub fn create_auth_context_with_adapter(
40    options: RustAuthOptions,
41    adapter: Arc<dyn DbAdapter>,
42) -> Result<AuthContext, RustAuthError> {
43    create_auth_context_with_environment_and_adapter(
44        options,
45        AuthEnvironment::from_process(),
46        Some(adapter),
47    )
48}
49
50pub fn create_auth_context_with_environment(
51    options: RustAuthOptions,
52    environment: AuthEnvironment,
53) -> Result<AuthContext, RustAuthError> {
54    create_auth_context_with_environment_and_adapter(options, environment, None)
55}
56
57pub fn create_auth_context_with_environment_and_adapter(
58    options: RustAuthOptions,
59    environment: AuthEnvironment,
60    adapter: Option<Arc<dyn DbAdapter>>,
61) -> Result<AuthContext, RustAuthError> {
62    let logger = create_logger(options.logger.clone());
63    let production_posture = is_production_posture(&options);
64    let env_secrets = parse_secrets_env(environment.rustauth_secrets.as_deref())?;
65    let secrets = if options.secrets.is_empty() {
66        env_secrets.unwrap_or_default()
67    } else {
68        options.secrets.clone()
69    };
70    let legacy_secret = resolve_legacy_secret(&options, &environment);
71
72    let (secret, secret_config) = if secrets.is_empty() {
73        let secret = legacy_secret.unwrap_or_else(|| DEFAULT_SECRET.to_owned());
74        validate_secret(&secret, &options)?;
75        (secret.clone(), SecretMaterial::Single(secret))
76    } else {
77        let config = build_secret_config(&secrets, legacy_secret.as_deref().unwrap_or(""))?;
78        let current = config
79            .keys
80            .get(&config.current_version)
81            .cloned()
82            .ok_or_else(|| {
83                RustAuthError::InvalidSecretConfig(format!(
84                    "secret version {} not found in keys",
85                    config.current_version
86                ))
87            })?;
88        (current, SecretMaterial::Rotating(config))
89    };
90
91    let base_path = options
92        .base_path
93        .clone()
94        .unwrap_or_else(|| "/api/auth".to_owned());
95    let base_url = options.base_url.clone().unwrap_or_default();
96    let trusted_origins = resolve_trusted_origins(&base_url, &options, &environment);
97    let auth_cookies = get_cookies(&options)?;
98    #[cfg(feature = "oauth")]
99    let social_providers = resolve_social_providers(&options)?;
100    let session_config = SessionConfig {
101        update_age: options.session.update_age.unwrap_or(Duration::hours(24)),
102        expires_in: options.session.expires_in.unwrap_or(Duration::days(7)),
103        fresh_age: options.session.fresh_age.unwrap_or(Duration::days(1)),
104        cookie_refresh_cache: options.session.cookie_cache.refresh_cache,
105    };
106    let password = PasswordContext {
107        config: PasswordPolicy {
108            min_password_length: options.password.min_password_length,
109            max_password_length: options.password.max_password_length,
110        },
111        hash: options.password.hash_password.unwrap_or(hash_password),
112        verify: options.password.verify_password.unwrap_or(verify_password),
113    };
114    validate_rate_limit_storage(&options)?;
115    let rate_limit = RateLimitContext {
116        enabled: options.rate_limit.enabled.unwrap_or(production_posture),
117        window: options.rate_limit.window,
118        max: options.rate_limit.max,
119        storage: options.rate_limit.storage,
120        custom_rules: options.rate_limit.custom_rules.clone(),
121        dynamic_rules: options.rate_limit.dynamic_rules.clone(),
122        plugin_rules: Vec::new(),
123        custom_store: options.rate_limit.custom_store.clone().or_else(|| {
124            options.rate_limit.custom_storage.clone().map(|storage| {
125                Arc::new(LegacyRateLimitStorageAdapter::new(storage)) as Arc<dyn RateLimitStore>
126            })
127        }),
128        hybrid: options.rate_limit.hybrid.clone(),
129        memory_cleanup_interval: options.rate_limit.memory_cleanup_interval,
130        memory_store: Arc::new(GovernorMemoryRateLimitStore::with_cleanup_interval(
131            options.rate_limit.memory_cleanup_interval,
132        )),
133        missing_ip_policy: options.rate_limit.missing_ip_policy,
134    };
135
136    let schema_options = schema_options_from_auth_options(&options);
137    let app_name = options
138        .app_name
139        .clone()
140        .unwrap_or_else(|| "RustAuth".to_owned());
141    let mut context =
142        AuthContext {
143            app_name,
144            base_url,
145            base_path,
146            options: options.clone(),
147            auth_cookies,
148            session_config,
149            secret,
150            secret_config,
151            password,
152            rate_limit,
153            trusted_origins,
154            disabled_paths: options.disabled_paths,
155            plugins: options.plugins,
156            adapter,
157            secondary_storage: options.secondary_storage.clone(),
158            background_tasks: options.advanced.background_tasks.clone().or_else(|| {
159                Some(Arc::new(TokioBackgroundTaskRunner) as Arc<dyn BackgroundTaskRunner>)
160            }),
161            #[cfg(feature = "oauth")]
162            social_providers,
163            db_schema: auth_schema(schema_options),
164            plugin_error_codes: BTreeMap::new(),
165            plugin_database_hooks: {
166                let mut hooks = plugin_database_hooks_from_init(&options.init_database_hooks);
167                hooks.extend(options.database_hooks.clone());
168                hooks
169            },
170            plugin_migrations: Vec::new(),
171            telemetry_publisher: noop_telemetry_publisher(),
172            logger,
173        };
174    apply_global_hooks(&mut context);
175    initialize_plugins(&mut context)?;
176    if !context.plugin_database_hooks.is_empty() {
177        if let Some(adapter) = context.adapter.clone() {
178            context.adapter = Some(Arc::new(HookedAdapter::with_logger(
179                adapter,
180                context.plugin_database_hooks.clone(),
181                context.logger.clone(),
182            )));
183        }
184    }
185    Ok(context)
186}
187
188fn apply_global_hooks(context: &mut AuthContext) {
189    let before = plugin_before_hooks(&context.options.hooks);
190    let after = plugin_after_hooks(&context.options.hooks);
191    if before.is_empty() && after.is_empty() {
192        return;
193    }
194    let mut plugin = AuthPlugin::new("__rustauth_global__");
195    plugin.hooks.before = before;
196    plugin.hooks.after = after;
197    context.plugins.insert(0, plugin);
198}
199
/// Collects `options.social_providers` into a map keyed by each provider's id.
///
/// # Errors
///
/// Fails on the first provider that `insert_social_provider` rejects, i.e. one
/// whose id is empty/whitespace-only or already present in the map.
#[cfg(feature = "oauth")]
fn resolve_social_providers(
    options: &RustAuthOptions,
) -> Result<BTreeMap<String, Arc<dyn SocialOAuthProvider>>, RustAuthError> {
    options
        .social_providers
        .iter()
        .try_fold(BTreeMap::new(), |mut registry, provider| {
            insert_social_provider(&mut registry, Arc::clone(provider)).map(|()| registry)
        })
}
210
/// Registers `provider` in `providers` under its id.
///
/// # Errors
///
/// Returns [`RustAuthError::InvalidConfig`] when the id is empty or
/// whitespace-only, or when a provider with the same id is already registered.
/// On error, `providers` is left unchanged. The previously registered provider
/// is never overwritten by a rejected duplicate.
#[cfg(feature = "oauth")]
pub(super) fn insert_social_provider(
    providers: &mut BTreeMap<String, Arc<dyn SocialOAuthProvider>>,
    provider: Arc<dyn SocialOAuthProvider>,
) -> Result<(), RustAuthError> {
    use std::collections::btree_map::Entry;

    let id = provider.id().to_owned();
    if id.trim().is_empty() {
        return Err(RustAuthError::InvalidConfig(
            "social provider id cannot be empty".to_owned(),
        ));
    }
    // Use the entry API so a duplicate is detected *before* any mutation;
    // `BTreeMap::insert` would replace the existing provider first.
    match providers.entry(id) {
        Entry::Occupied(existing) => Err(RustAuthError::InvalidConfig(format!(
            "duplicate social provider `{}`",
            existing.key()
        ))),
        Entry::Vacant(slot) => {
            slot.insert(provider);
            Ok(())
        }
    }
}
229
230fn validate_rate_limit_storage(options: &RustAuthOptions) -> Result<(), RustAuthError> {
231    if options.rate_limit.custom_store.is_some() || options.rate_limit.custom_storage.is_some() {
232        return Ok(());
233    }
234    if matches!(
235        options.rate_limit.storage,
236        RateLimitStorageOption::Database | RateLimitStorageOption::SecondaryStorage
237    ) {
238        return Err(RustAuthError::InvalidConfig(
239            "rate_limit.custom_store or rate_limit.custom_storage is required when using database or secondary-storage rate limiting without a concrete adapter".to_owned(),
240        ));
241    }
242    Ok(())
243}
244
245fn apply_model_schema(table: &mut crate::db::TableOptions, schema: &ModelSchemaOptions) {
246    table.name = schema.model_name.clone();
247    table.field_names = schema
248        .field_names
249        .iter()
250        .map(|(key, value)| (key.clone(), value.clone()))
251        .collect();
252}
253
254fn schema_options_from_auth_options(options: &RustAuthOptions) -> AuthSchemaOptions {
255    let mut schema_options = AuthSchemaOptions {
256        has_secondary_storage: options.secondary_storage.is_some(),
257        store_session_in_database: options.session.store_session_in_database,
258        rate_limit_storage: match options.rate_limit.storage {
259            RateLimitStorageOption::Memory => DbRateLimitStorage::Memory,
260            RateLimitStorageOption::Database => DbRateLimitStorage::Database,
261            RateLimitStorageOption::SecondaryStorage => DbRateLimitStorage::SecondaryStorage,
262        },
263        ..AuthSchemaOptions::default()
264    };
265    apply_model_schema(&mut schema_options.user, &options.user.schema);
266    apply_model_schema(&mut schema_options.session, &options.session.schema);
267    apply_model_schema(&mut schema_options.account, &options.account.schema);
268    apply_model_schema(
269        &mut schema_options.verification,
270        &options.verification.schema,
271    );
272    apply_model_schema(&mut schema_options.rate_limit, &options.rate_limit.schema);
273    for (name, field) in &options.user.additional_fields {
274        schema_options
275            .user
276            .additional_fields
277            .insert(name.clone(), user_additional_field_to_db_field(name, field));
278    }
279    for (name, field) in &options.session.additional_fields {
280        schema_options.session.additional_fields.insert(
281            name.clone(),
282            session_additional_field_to_db_field(name, field),
283        );
284    }
285    schema_options
286}
287
288pub(super) fn user_additional_field_to_db_field(
289    logical_name: &str,
290    field: &UserAdditionalField,
291) -> DbField {
292    additional_field_to_db_field(
293        logical_name,
294        field.db_name.as_deref(),
295        field.field_type.clone(),
296        field.required,
297        field.input,
298        field.returned,
299    )
300}
301
302pub(super) fn session_additional_field_to_db_field(
303    logical_name: &str,
304    field: &SessionAdditionalField,
305) -> DbField {
306    additional_field_to_db_field(
307        logical_name,
308        field.db_name.as_deref(),
309        field.field_type.clone(),
310        field.required,
311        field.input,
312        field.returned,
313    )
314}
315
316fn additional_field_to_db_field(
317    logical_name: &str,
318    db_name: Option<&str>,
319    field_type: crate::db::DbFieldType,
320    required: bool,
321    input: bool,
322    returned: bool,
323) -> DbField {
324    let mut field = DbField::new(db_name.unwrap_or(logical_name), field_type);
325    if !required {
326        field = field.optional();
327    }
328    if !input {
329        field = field.generated();
330    }
331    if !returned {
332        field = field.hidden();
333    }
334    field
335}