// rustauth_core/context/request_state.rs

1use std::any::Any;
2use std::cell::RefCell;
3use std::collections::HashMap;
4use std::future::Future;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::{Arc, OnceLock};
7
8use crate::db::{Session, User};
9use crate::error::RustAuthError;
10use serde_json::Value;
11
12tokio::task_local! {
13    static REQUEST_STATE: RefCell<RequestStateStore>;
14}
15
/// Monotonic counter from which every [`RequestStateKey`] draws its id
/// (the first id handed out is 1).
static NEXT_KEY: AtomicU64 = AtomicU64::new(1);

/// Request-scoped state storage.
///
/// Maps each [`RequestStateKey`] to a type-erased value. A fresh, empty store
/// is created for every [`run_with_request_state`] scope.
#[derive(Default)]
pub struct RequestStateStore {
    values: HashMap<RequestStateKey, Box<dyn Any + Send>>,
}

impl RequestStateStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }
}

// `Box<dyn Any + Send>` is not `Debug`, so `#[derive(Debug)]` is unavailable.
// Report how many entries are stored rather than their (opaque) contents.
impl std::fmt::Debug for RequestStateStore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RequestStateStore")
            .field("len", &self.values.len())
            .finish_non_exhaustive()
    }
}

/// Opaque identifier of one [`RequestState`] slot inside a [`RequestStateStore`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RequestStateKey(u64);

impl RequestStateKey {
    /// Allocates a key that no other call in this process has returned.
    fn new() -> Self {
        // Only uniqueness of the id matters, not ordering with respect to other
        // memory operations, so `Relaxed` is sufficient for `fetch_add`.
        Self(NEXT_KEY.fetch_add(1, Ordering::Relaxed))
    }
}
38
39/// Request-scoped typed value.
40pub struct RequestState<T> {
41    key: RequestStateKey,
42    init: Arc<dyn Fn() -> T + Send + Sync>,
43}
44
45impl<T> Clone for RequestState<T> {
46    fn clone(&self) -> Self {
47        Self {
48            key: self.key,
49            init: Arc::clone(&self.init),
50        }
51    }
52}
53
54impl<T> RequestState<T>
55where
56    T: Clone + Send + 'static,
57{
58    /// Get the value for this request, lazily initializing it when absent.
59    pub fn get(&self) -> Result<T, RustAuthError> {
60        with_current_store(|store| {
61            if let Some(value) = store.values.get(&self.key) {
62                return value
63                    .downcast_ref::<T>()
64                    .cloned()
65                    .ok_or(RustAuthError::RequestStateTypeMismatch);
66            }
67
68            let value = (self.init)();
69            store.values.insert(self.key, Box::new(value.clone()));
70            Ok(value)
71        })
72    }
73
74    /// Set the value for this request.
75    pub fn set(&self, value: T) -> Result<(), RustAuthError> {
76        with_current_store(|store| {
77            store.values.insert(self.key, Box::new(value));
78            Ok(())
79        })
80    }
81
82    /// Unique key for debugging or custom stores.
83    pub fn key(&self) -> RequestStateKey {
84        self.key
85    }
86}
87
88/// Define a typed request-scoped state value.
89pub fn define_request_state<T>(init: impl Fn() -> T + Send + Sync + 'static) -> RequestState<T>
90where
91    T: Clone + Send + 'static,
92{
93    RequestState {
94        key: RequestStateKey::new(),
95        init: Arc::new(init),
96    }
97}
98
99static CURRENT_SESSION_USER: OnceLock<RequestState<Option<Value>>> = OnceLock::new();
100static CURRENT_SESSION: OnceLock<RequestState<Option<CurrentSession>>> = OnceLock::new();
101static CURRENT_NEW_SESSION: OnceLock<RequestState<Option<NewSession>>> = OnceLock::new();
102static CURRENT_REQUEST_PATH: OnceLock<RequestState<Option<String>>> = OnceLock::new();
103static REQUEST_IS_EXTERNAL: OnceLock<RequestState<bool>> = OnceLock::new();
104static SHOULD_SKIP_SESSION_REFRESH: OnceLock<RequestState<bool>> = OnceLock::new();
105
106#[derive(Debug, Clone, PartialEq, Eq)]
107pub struct NewSession {
108    pub session: Session,
109    pub user: User,
110}
111
112#[derive(Debug, Clone, PartialEq, Eq)]
113pub struct CurrentSession {
114    pub session: Session,
115    pub user: User,
116}
117
118fn current_session_user_state() -> &'static RequestState<Option<Value>> {
119    CURRENT_SESSION_USER.get_or_init(|| define_request_state(|| None))
120}
121
122fn current_request_path_state() -> &'static RequestState<Option<String>> {
123    CURRENT_REQUEST_PATH.get_or_init(|| define_request_state(|| None))
124}
125
126/// Store the current session user JSON for after-response hooks in this request.
127pub fn set_current_session_user(user: Value) -> Result<(), RustAuthError> {
128    current_session_user_state().set(Some(user))
129}
130
131/// Read the current session user JSON for this request, when an endpoint resolved one.
132pub fn current_session_user() -> Result<Option<Value>, RustAuthError> {
133    current_session_user_state().get()
134}
135
136fn current_session_state() -> &'static RequestState<Option<CurrentSession>> {
137    CURRENT_SESSION.get_or_init(|| define_request_state(|| None))
138}
139
140pub fn set_current_session(session: Session, user: User) -> Result<(), RustAuthError> {
141    current_session_state().set(Some(CurrentSession { session, user }))
142}
143
144pub fn current_session() -> Result<Option<CurrentSession>, RustAuthError> {
145    current_session_state().get()
146}
147
148fn current_new_session_state() -> &'static RequestState<Option<NewSession>> {
149    CURRENT_NEW_SESSION.get_or_init(|| define_request_state(|| None))
150}
151
152pub fn set_current_new_session(session: Session, user: User) -> Result<(), RustAuthError> {
153    current_new_session_state().set(Some(NewSession { session, user }))
154}
155
156pub fn current_new_session() -> Result<Option<NewSession>, RustAuthError> {
157    current_new_session_state().get()
158}
159
160/// Store the normalized endpoint path for hooks running in this request.
161pub fn set_current_request_path(path: impl Into<String>) -> Result<(), RustAuthError> {
162    current_request_path_state().set(Some(path.into()))
163}
164
165/// Read the normalized endpoint path for this request, when available.
166pub fn current_request_path() -> Result<Option<String>, RustAuthError> {
167    current_request_path_state().get()
168}
169
170fn request_is_external_state() -> &'static RequestState<bool> {
171    REQUEST_IS_EXTERNAL.get_or_init(|| define_request_state(|| false))
172}
173
174fn should_skip_session_refresh_state() -> &'static RequestState<bool> {
175    SHOULD_SKIP_SESSION_REFRESH.get_or_init(|| define_request_state(|| false))
176}
177
178/// Mark whether the current request originated from the internet-facing HTTP
179/// router. Trusted server-side invocations leave this `false`.
180pub fn set_request_external(external: bool) -> Result<(), RustAuthError> {
181    request_is_external_state().set(external)
182}
183
184/// Returns true only when the current request is known to originate from the
185/// internet-facing HTTP router. Absent request state is treated as a trusted
186/// server-side call (`false`).
187pub fn is_external_request() -> bool {
188    if !has_request_state() {
189        return false;
190    }
191    request_is_external_state().get().unwrap_or(false)
192}
193
194/// Mark whether session resolution should skip refresh for the current request.
195pub fn set_should_skip_session_refresh(skip: bool) -> Result<(), RustAuthError> {
196    should_skip_session_refresh_state().set(skip)
197}
198
199/// Returns true when the current request explicitly disables session refresh.
200pub fn should_skip_session_refresh() -> bool {
201    if !has_request_state() {
202        return false;
203    }
204    should_skip_session_refresh_state().get().unwrap_or(false)
205}
206
207/// Run a future inside a fresh request state scope.
208pub async fn run_with_request_state<F>(future: F) -> F::Output
209where
210    F: Future,
211{
212    REQUEST_STATE
213        .scope(RefCell::new(RequestStateStore::new()), future)
214        .await
215}
216
217/// Returns true when the current async task has request state.
218pub fn has_request_state() -> bool {
219    REQUEST_STATE.try_with(|_| ()).is_ok()
220}
221
222fn with_current_store<T>(
223    operation: impl FnOnce(&mut RequestStateStore) -> Result<T, RustAuthError>,
224) -> Result<T, RustAuthError> {
225    REQUEST_STATE
226        .try_with(|store| operation(&mut store.borrow_mut()))
227        .map_err(|_| RustAuthError::RequestStateMissing)?
228}