Skip to main content

rustauth_plugins/api_key/
mod.rs

1//! API key plugin.
2
3mod cleanup;
4mod errors;
5mod hashing;
6mod models;
7mod options;
8mod organization;
9mod permissions;
10mod rate_limit;
11mod routes;
12mod schema;
13mod storage;
14
15use std::sync::Arc;
16
17use http::{header, StatusCode};
18use rustauth_core::api::output::{session_output_value, user_output_value};
19use rustauth_core::context::request_state;
20use rustauth_core::db::Session;
21use rustauth_core::error::RustAuthError;
22use rustauth_core::plugin::{AuthPlugin, PluginBeforeHookAction};
23use rustauth_core::utils::ip::{is_valid_ip, normalize_ip_with_options, NormalizeIpOptions};
24use serde::Serialize;
25use serde_json::Value;
26use time::OffsetDateTime;
27
28pub use errors::*;
29pub use hashing::default_key_hasher;
30pub use models::{ApiKeyCreateRecord, ApiKeyPublicRecord, ApiKeyRecord};
31pub use options::{
32    ApiKeyConfiguration, ApiKeyExpirationOptions, ApiKeyGenerator, ApiKeyGeneratorInput,
33    ApiKeyGetter, ApiKeyOptions, ApiKeyOptionsBuilder, ApiKeyOptionsError, ApiKeyPermissions,
34    ApiKeyRateLimitOptions, ApiKeyReference, ApiKeyStorageMode, ApiKeyValidator,
35    DefaultPermissionsResolver, StartingCharactersConfig,
36};
37pub use routes::{
38    CreateApiKeyRequest, DeleteApiKeyRequest, GetApiKeyQuery, ListApiKeysQuery,
39    UpdateApiKeyRequest, UpdateField, VerifyApiKeyRequest, VerifyApiKeyResponse,
40};
41pub use schema::ApiKeySchemaOptions;
42
/// Identifier this plugin registers under (passed to `AuthPlugin::new` in
/// `build_plugin`); named after the upstream plugin id.
pub const UPSTREAM_PLUGIN_ID: &str = "api-key";
/// Model name for API key records.
///
/// NOTE(review): presumably consumed by the schema/storage layers — confirm in `schema`.
pub const API_KEY_MODEL: &str = "api_key";
/// Table name for API key records.
///
/// NOTE(review): presumably the default physical table name — confirm against
/// `ApiKeySchemaOptions`.
pub const API_KEY_TABLE: &str = "api_keys";
46
47pub fn api_key(options: ApiKeyOptions) -> Result<AuthPlugin, RustAuthError> {
48    Ok(build_plugin(options.resolve()?))
49}
50
51fn build_plugin(configurations: options::ResolvedConfigurations) -> AuthPlugin {
52    let schema = configurations.schema().clone();
53    let configurations = Arc::new(configurations);
54    let mut plugin = AuthPlugin::new(UPSTREAM_PLUGIN_ID)
55        .with_version(crate::VERSION)
56        .with_schema(schema::schema_contribution(&schema))
57        .with_endpoint(routes::create_endpoint(Arc::clone(&configurations)))
58        .with_endpoint(routes::verify_endpoint(Arc::clone(&configurations)))
59        .with_endpoint(routes::get_endpoint(Arc::clone(&configurations)))
60        .with_endpoint(routes::update_endpoint(Arc::clone(&configurations)))
61        .with_endpoint(routes::delete_endpoint(Arc::clone(&configurations)))
62        .with_endpoint(routes::list_endpoint(Arc::clone(&configurations)))
63        .with_endpoint(routes::delete_expired_endpoint(Arc::clone(&configurations)))
64        .with_async_before_hook("*", move |context, request| {
65            let configurations = Arc::clone(&configurations);
66            Box::pin(async move { session_hook(context, request, configurations).await })
67        });
68    for error_code in errors::plugin_error_codes() {
69        plugin = plugin.with_error_code(error_code);
70    }
71    plugin
72}
73
74async fn session_hook(
75    context: &rustauth_core::context::AuthContext,
76    request: rustauth_core::plugin::PluginRequest,
77    configurations: Arc<options::ResolvedConfigurations>,
78) -> Result<PluginBeforeHookAction, RustAuthError> {
79    let Some((raw_key, options)) = find_session_key(context, &request, &configurations).await?
80    else {
81        return Ok(PluginBeforeHookAction::Continue(request));
82    };
83    if raw_key.len() < options.default_key_length {
84        return errors::error_response(StatusCode::FORBIDDEN, errors::INVALID_API_KEY)
85            .map(PluginBeforeHookAction::Respond);
86    }
87    if let Some(validator) = &options.custom_api_key_validator {
88        if !validator(context, &raw_key).await? {
89            return errors::error_response(StatusCode::FORBIDDEN, errors::INVALID_API_KEY)
90                .map(PluginBeforeHookAction::Respond);
91        }
92    }
93    let hashed = if options.disable_key_hashing {
94        raw_key.clone()
95    } else {
96        hashing::default_key_hasher(&raw_key)
97    };
98    let api_key = match routes::validate_api_key(context, &options, &hashed, None).await {
99        Ok(api_key) => api_key,
100        Err(error) => {
101            return errors::error_response(error.status, error.code)
102                .map(PluginBeforeHookAction::Respond);
103        }
104    };
105    if options.reference != ApiKeyReference::User {
106        return errors::error_response(
107            StatusCode::UNAUTHORIZED,
108            errors::INVALID_REFERENCE_ID_FROM_API_KEY,
109        )
110        .map(PluginBeforeHookAction::Respond);
111    }
112    let Some(_adapter) = context.adapter() else {
113        return Ok(PluginBeforeHookAction::Continue(request));
114    };
115    let Some(user) = context
116        .users()?
117        .find_user_by_id(&api_key.reference_id)
118        .await?
119    else {
120        return errors::error_response(
121            StatusCode::UNAUTHORIZED,
122            errors::INVALID_REFERENCE_ID_FROM_API_KEY,
123        )
124        .map(PluginBeforeHookAction::Respond);
125    };
126    let now = OffsetDateTime::now_utc();
127    let expires_at = match api_key.expires_at {
128        Some(expires_at) => expires_at,
129        None => session_expiration_from_context(context, now)?,
130    };
131    let session = Session {
132        id: api_key.id.clone(),
133        user_id: api_key.reference_id.clone(),
134        expires_at,
135        token: raw_key,
136        ip_address: request_ip(context, &request),
137        user_agent: request
138            .headers()
139            .get(header::USER_AGENT)
140            .and_then(|value| value.to_str().ok())
141            .map(str::to_owned),
142        created_at: now,
143        updated_at: now,
144    };
145    if request_state::has_request_state() {
146        request_state::set_current_session(session.clone(), user.clone())?;
147    }
148    if request.uri().path().ends_with("/get-session") {
149        return session_response(context, session, user)
150            .await
151            .map(PluginBeforeHookAction::Respond);
152    }
153    Ok(PluginBeforeHookAction::Continue(request))
154}
155
156async fn find_session_key(
157    context: &rustauth_core::context::AuthContext,
158    request: &rustauth_core::plugin::PluginRequest,
159    configurations: &options::ResolvedConfigurations,
160) -> Result<Option<(String, ApiKeyConfiguration)>, RustAuthError> {
161    for configuration in configurations
162        .all()
163        .iter()
164        .filter(|configuration| configuration.enable_session_for_api_keys)
165    {
166        if let Some(getter) = &configuration.custom_api_key_getter {
167            if let Some(key) = getter(context, request).await? {
168                return Ok(Some((key, configuration.clone())));
169            }
170            continue;
171        }
172        for header_name in &configuration.api_key_headers {
173            if let Some(value) = request
174                .headers()
175                .get(header_name)
176                .and_then(|value| value.to_str().ok())
177            {
178                return Ok(Some((value.to_owned(), configuration.clone())));
179            }
180        }
181    }
182    Ok(None)
183}
184
185fn request_ip(
186    context: &rustauth_core::context::AuthContext,
187    request: &rustauth_core::plugin::PluginRequest,
188) -> Option<String> {
189    if context.options.advanced.ip_address.disable_ip_tracking {
190        return None;
191    }
192
193    for header_name in &context.options.advanced.ip_address.headers {
194        let Some(value) = request
195            .headers()
196            .get(header_name.as_str())
197            .and_then(|value| value.to_str().ok())
198        else {
199            continue;
200        };
201        for candidate in value.split(',').map(str::trim) {
202            if candidate.is_empty() || !is_valid_ip(candidate) {
203                continue;
204            }
205            return Some(normalize_ip_with_options(
206                candidate,
207                NormalizeIpOptions {
208                    ipv6_subnet: context.options.advanced.ip_address.ipv6_subnet,
209                },
210            ));
211        }
212    }
213
214    None
215}
216
217fn session_expiration_from_context(
218    context: &rustauth_core::context::AuthContext,
219    now: OffsetDateTime,
220) -> Result<OffsetDateTime, RustAuthError> {
221    now.checked_add(context.session_config.expires_in)
222        .ok_or(RustAuthError::NumericOutOfRange {
223            context: "session.expires_in",
224        })
225}
226
227#[derive(Serialize)]
228struct SessionResponse {
229    session: Value,
230    user: Value,
231}
232
233async fn session_response(
234    context: &rustauth_core::context::AuthContext,
235    session: Session,
236    user: rustauth_core::db::User,
237) -> Result<rustauth_core::plugin::PluginResponse, RustAuthError> {
238    let session = session_output_value(context.adapter_ref()?, context, &session).await?;
239    let user = user_output_value(context.adapter_ref()?, context, &user).await?;
240    let body = serde_json::to_vec(&SessionResponse { session, user })
241        .map_err(|error| RustAuthError::Api(error.to_string()))?;
242    http::Response::builder()
243        .status(StatusCode::OK)
244        .header(header::CONTENT_TYPE, "application/json")
245        .body(body)
246        .map_err(|error| RustAuthError::Api(error.to_string()))
247}