rustauth_plugins/custom_session/
mod.rs1use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use http::{header, StatusCode};
8use rustauth_core::api::{ApiRequest, ApiResponse};
9use rustauth_core::context::AuthContext;
10use rustauth_core::error::RustAuthError;
11use rustauth_core::plugin::{AuthPlugin, PluginAfterHookAction, PluginAfterHookFuture};
12use serde::Serialize;
13use serde_json::Value;
14
/// Identifier of the upstream `custom-session` plugin; [`custom_session`] registers the
/// returned `AuthPlugin` under this id.
pub const UPSTREAM_PLUGIN_ID: &str = "custom-session";
16
17#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize)]
19#[serde(rename_all = "camelCase")]
20pub struct CustomSessionOptions {
21 pub should_mutate_list_device_sessions_endpoint: bool,
22}
23
24impl CustomSessionOptions {
25 #[must_use]
26 pub fn builder() -> CustomSessionOptionsBuilder {
27 CustomSessionOptionsBuilder::default()
28 }
29}
30
31#[derive(Debug, Clone, Copy, Default)]
32pub struct CustomSessionOptionsBuilder {
33 should_mutate_list_device_sessions_endpoint: Option<bool>,
34}
35
36impl CustomSessionOptionsBuilder {
37 #[must_use]
38 pub fn should_mutate_list_device_sessions_endpoint(mut self, enabled: bool) -> Self {
39 self.should_mutate_list_device_sessions_endpoint = Some(enabled);
40 self
41 }
42
43 #[must_use]
44 pub fn build(self) -> CustomSessionOptions {
45 let defaults = CustomSessionOptions::default();
46 CustomSessionOptions {
47 should_mutate_list_device_sessions_endpoint: self
48 .should_mutate_list_device_sessions_endpoint
49 .unwrap_or(defaults.should_mutate_list_device_sessions_endpoint),
50 }
51 }
52}
53
54#[derive(Debug, Clone, PartialEq)]
56pub struct CustomSessionInput {
57 pub user: Value,
58 pub session: Value,
59}
60
61#[derive(Clone, Copy)]
63pub struct CustomSessionContext<'a> {
64 pub auth_context: &'a AuthContext,
65 pub request: &'a ApiRequest,
66}
67
68pub type CustomSessionFuture<'a> =
69 Pin<Box<dyn Future<Output = Result<Value, RustAuthError>> + Send + 'a>>;
70
71type CustomSessionHandler = Arc<
72 dyn for<'a> Fn(CustomSessionInput, CustomSessionContext<'a>) -> CustomSessionFuture<'a>
73 + Send
74 + Sync,
75>;
76
77#[must_use]
79pub fn custom_session<F>(options: CustomSessionOptions, handler: F) -> AuthPlugin
80where
81 F: for<'a> Fn(CustomSessionInput, CustomSessionContext<'a>) -> CustomSessionFuture<'a>
82 + Send
83 + Sync
84 + 'static,
85{
86 let handler: CustomSessionHandler = Arc::new(handler);
87 let mut plugin = AuthPlugin::new(UPSTREAM_PLUGIN_ID)
88 .with_version(env!("CARGO_PKG_VERSION"))
89 .with_options(serde_json::to_value(options).unwrap_or(Value::Null))
90 .with_async_after_hook("/get-session", {
91 let handler = Arc::clone(&handler);
92 move |context, request, response| {
93 transform_get_session_response(&handler, context, request, response)
94 }
95 });
96
97 if options.should_mutate_list_device_sessions_endpoint {
98 plugin = plugin.with_async_after_hook("/multi-session/list-device-sessions", {
99 let handler = Arc::clone(&handler);
100 move |context, request, response| {
101 transform_list_device_sessions_response(&handler, context, request, response)
102 }
103 });
104 }
105
106 plugin
107}
108
109fn transform_get_session_response<'a>(
110 handler: &CustomSessionHandler,
111 auth_context: &'a AuthContext,
112 request: &'a ApiRequest,
113 response: ApiResponse,
114) -> PluginAfterHookFuture<'a> {
115 let handler = Arc::clone(handler);
116 Box::pin(async move {
117 if response.status() != StatusCode::OK {
118 return Ok(PluginAfterHookAction::Continue(response));
119 }
120 let (parts, body) = response.into_parts();
121 let value = response_json(&body)?;
122 if value.is_null() {
123 return Ok(PluginAfterHookAction::Continue(ApiResponse::from_parts(
124 parts, body,
125 )));
126 }
127 let input = custom_session_input(value)?;
128 let custom = handler(
129 input,
130 CustomSessionContext {
131 auth_context,
132 request,
133 },
134 )
135 .await?;
136 Ok(PluginAfterHookAction::Continue(json_response(
137 parts, &custom,
138 )?))
139 })
140}
141
142fn transform_list_device_sessions_response<'a>(
143 handler: &CustomSessionHandler,
144 auth_context: &'a AuthContext,
145 request: &'a ApiRequest,
146 response: ApiResponse,
147) -> PluginAfterHookFuture<'a> {
148 let handler = Arc::clone(handler);
149 Box::pin(async move {
150 if response.status() != StatusCode::OK {
151 return Ok(PluginAfterHookAction::Continue(response));
152 }
153 let (parts, body) = response.into_parts();
154 let value = response_json(&body)?;
155 let Some(sessions) = value.as_array() else {
156 return Err(RustAuthError::Api(
157 "custom-session expected list-device-sessions response to be an array".to_owned(),
158 ));
159 };
160 let mut custom_sessions = Vec::with_capacity(sessions.len());
161 for session in sessions {
162 let input = custom_session_input(session.clone())?;
163 custom_sessions.push(
164 handler(
165 input,
166 CustomSessionContext {
167 auth_context,
168 request,
169 },
170 )
171 .await?,
172 );
173 }
174 Ok(PluginAfterHookAction::Continue(json_response(
175 parts,
176 &Value::Array(custom_sessions),
177 )?))
178 })
179}
180
181fn custom_session_input(value: Value) -> Result<CustomSessionInput, RustAuthError> {
182 let Value::Object(mut object) = value else {
183 return Err(RustAuthError::Api(
184 "custom-session expected session response to be an object".to_owned(),
185 ));
186 };
187 let Some(user) = object.remove("user") else {
188 return Err(RustAuthError::Api(
189 "custom-session expected session response to include user".to_owned(),
190 ));
191 };
192 let Some(session) = object.remove("session") else {
193 return Err(RustAuthError::Api(
194 "custom-session expected session response to include session".to_owned(),
195 ));
196 };
197 Ok(CustomSessionInput { user, session })
198}
199
200fn response_json(body: &[u8]) -> Result<Value, RustAuthError> {
201 serde_json::from_slice(body).map_err(|error| RustAuthError::Api(error.to_string()))
202}
203
204fn json_response(
205 mut parts: http::response::Parts,
206 body: &Value,
207) -> Result<ApiResponse, RustAuthError> {
208 parts.headers.insert(
209 header::CONTENT_TYPE,
210 http::HeaderValue::from_static("application/json"),
211 );
212 parts.headers.remove(header::CONTENT_LENGTH);
213 let body = serde_json::to_vec(body).map_err(|error| RustAuthError::Api(error.to_string()))?;
214 Ok(ApiResponse::from_parts(parts, body))
215}