Skip to main content

rustauth_plugins/custom_session/
mod.rs

1//! Custom session plugin.
2
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use http::{header, StatusCode};
8use rustauth_core::api::{ApiRequest, ApiResponse};
9use rustauth_core::context::AuthContext;
10use rustauth_core::error::RustAuthError;
11use rustauth_core::plugin::{AuthPlugin, PluginAfterHookAction, PluginAfterHookFuture};
12use serde::Serialize;
13use serde_json::Value;
14
/// Identifier this plugin is registered under; passed to `AuthPlugin::new` by
/// `custom_session`.
pub const UPSTREAM_PLUGIN_ID: &str = "custom-session";
16
17/// Options for the custom session plugin.
18#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize)]
19#[serde(rename_all = "camelCase")]
20pub struct CustomSessionOptions {
21    pub should_mutate_list_device_sessions_endpoint: bool,
22}
23
24impl CustomSessionOptions {
25    #[must_use]
26    pub fn builder() -> CustomSessionOptionsBuilder {
27        CustomSessionOptionsBuilder::default()
28    }
29}
30
/// Incremental builder for [`CustomSessionOptions`].
///
/// Each field is `None` until explicitly set; unset fields fall back to the values
/// of `CustomSessionOptions::default()` when the builder is finished.
#[derive(Clone, Copy, Debug, Default)]
pub struct CustomSessionOptionsBuilder {
    /// Pending value for `should_mutate_list_device_sessions_endpoint`
    /// (`None` = use the default).
    should_mutate_list_device_sessions_endpoint: Option<bool>,
}
35
36impl CustomSessionOptionsBuilder {
37    #[must_use]
38    pub fn should_mutate_list_device_sessions_endpoint(mut self, enabled: bool) -> Self {
39        self.should_mutate_list_device_sessions_endpoint = Some(enabled);
40        self
41    }
42
43    #[must_use]
44    pub fn build(self) -> CustomSessionOptions {
45        let defaults = CustomSessionOptions::default();
46        CustomSessionOptions {
47            should_mutate_list_device_sessions_endpoint: self
48                .should_mutate_list_device_sessions_endpoint
49                .unwrap_or(defaults.should_mutate_list_device_sessions_endpoint),
50        }
51    }
52}
53
54/// Session payload passed to the custom session handler.
55#[derive(Debug, Clone, PartialEq)]
56pub struct CustomSessionInput {
57    pub user: Value,
58    pub session: Value,
59}
60
61/// Request context available to custom session handlers.
62#[derive(Clone, Copy)]
63pub struct CustomSessionContext<'a> {
64    pub auth_context: &'a AuthContext,
65    pub request: &'a ApiRequest,
66}
67
68pub type CustomSessionFuture<'a> =
69    Pin<Box<dyn Future<Output = Result<Value, RustAuthError>> + Send + 'a>>;
70
71type CustomSessionHandler = Arc<
72    dyn for<'a> Fn(CustomSessionInput, CustomSessionContext<'a>) -> CustomSessionFuture<'a>
73        + Send
74        + Sync,
75>;
76
77/// Create a custom session plugin with options and request-aware handler.
78#[must_use]
79pub fn custom_session<F>(options: CustomSessionOptions, handler: F) -> AuthPlugin
80where
81    F: for<'a> Fn(CustomSessionInput, CustomSessionContext<'a>) -> CustomSessionFuture<'a>
82        + Send
83        + Sync
84        + 'static,
85{
86    let handler: CustomSessionHandler = Arc::new(handler);
87    let mut plugin = AuthPlugin::new(UPSTREAM_PLUGIN_ID)
88        .with_version(env!("CARGO_PKG_VERSION"))
89        .with_options(serde_json::to_value(options).unwrap_or(Value::Null))
90        .with_async_after_hook("/get-session", {
91            let handler = Arc::clone(&handler);
92            move |context, request, response| {
93                transform_get_session_response(&handler, context, request, response)
94            }
95        });
96
97    if options.should_mutate_list_device_sessions_endpoint {
98        plugin = plugin.with_async_after_hook("/multi-session/list-device-sessions", {
99            let handler = Arc::clone(&handler);
100            move |context, request, response| {
101                transform_list_device_sessions_response(&handler, context, request, response)
102            }
103        });
104    }
105
106    plugin
107}
108
109fn transform_get_session_response<'a>(
110    handler: &CustomSessionHandler,
111    auth_context: &'a AuthContext,
112    request: &'a ApiRequest,
113    response: ApiResponse,
114) -> PluginAfterHookFuture<'a> {
115    let handler = Arc::clone(handler);
116    Box::pin(async move {
117        if response.status() != StatusCode::OK {
118            return Ok(PluginAfterHookAction::Continue(response));
119        }
120        let (parts, body) = response.into_parts();
121        let value = response_json(&body)?;
122        if value.is_null() {
123            return Ok(PluginAfterHookAction::Continue(ApiResponse::from_parts(
124                parts, body,
125            )));
126        }
127        let input = custom_session_input(value)?;
128        let custom = handler(
129            input,
130            CustomSessionContext {
131                auth_context,
132                request,
133            },
134        )
135        .await?;
136        Ok(PluginAfterHookAction::Continue(json_response(
137            parts, &custom,
138        )?))
139    })
140}
141
142fn transform_list_device_sessions_response<'a>(
143    handler: &CustomSessionHandler,
144    auth_context: &'a AuthContext,
145    request: &'a ApiRequest,
146    response: ApiResponse,
147) -> PluginAfterHookFuture<'a> {
148    let handler = Arc::clone(handler);
149    Box::pin(async move {
150        if response.status() != StatusCode::OK {
151            return Ok(PluginAfterHookAction::Continue(response));
152        }
153        let (parts, body) = response.into_parts();
154        let value = response_json(&body)?;
155        let Some(sessions) = value.as_array() else {
156            return Err(RustAuthError::Api(
157                "custom-session expected list-device-sessions response to be an array".to_owned(),
158            ));
159        };
160        let mut custom_sessions = Vec::with_capacity(sessions.len());
161        for session in sessions {
162            let input = custom_session_input(session.clone())?;
163            custom_sessions.push(
164                handler(
165                    input,
166                    CustomSessionContext {
167                        auth_context,
168                        request,
169                    },
170                )
171                .await?,
172            );
173        }
174        Ok(PluginAfterHookAction::Continue(json_response(
175            parts,
176            &Value::Array(custom_sessions),
177        )?))
178    })
179}
180
181fn custom_session_input(value: Value) -> Result<CustomSessionInput, RustAuthError> {
182    let Value::Object(mut object) = value else {
183        return Err(RustAuthError::Api(
184            "custom-session expected session response to be an object".to_owned(),
185        ));
186    };
187    let Some(user) = object.remove("user") else {
188        return Err(RustAuthError::Api(
189            "custom-session expected session response to include user".to_owned(),
190        ));
191    };
192    let Some(session) = object.remove("session") else {
193        return Err(RustAuthError::Api(
194            "custom-session expected session response to include session".to_owned(),
195        ));
196    };
197    Ok(CustomSessionInput { user, session })
198}
199
200fn response_json(body: &[u8]) -> Result<Value, RustAuthError> {
201    serde_json::from_slice(body).map_err(|error| RustAuthError::Api(error.to_string()))
202}
203
204fn json_response(
205    mut parts: http::response::Parts,
206    body: &Value,
207) -> Result<ApiResponse, RustAuthError> {
208    parts.headers.insert(
209        header::CONTENT_TYPE,
210        http::HeaderValue::from_static("application/json"),
211    );
212    parts.headers.remove(header::CONTENT_LENGTH);
213    let body = serde_json::to_vec(body).map_err(|error| RustAuthError::Api(error.to_string()))?;
214    Ok(ApiResponse::from_parts(parts, body))
215}