rustauth_core/context/
request_state.rs1use std::any::Any;
2use std::cell::RefCell;
3use std::collections::HashMap;
4use std::future::Future;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::{Arc, OnceLock};
7
8use crate::db::{Session, User};
9use crate::error::RustAuthError;
10use serde_json::Value;
11
12tokio::task_local! {
13 static REQUEST_STATE: RefCell<RequestStateStore>;
14}
15
/// Global counter from which every `RequestStateKey` is allocated; the first key is 1.
static NEXT_KEY: AtomicU64 = AtomicU64::new(1);

/// Type-erased map backing a single request scope. Each `RequestState` handle stores its
/// value under its own `RequestStateKey`.
#[derive(Default)]
pub struct RequestStateStore {
    values: HashMap<RequestStateKey, Box<dyn Any + Send>>,
}

impl RequestStateStore {
    /// Creates an empty store (same as `RequestStateStore::default()`).
    pub fn new() -> Self {
        Self {
            values: HashMap::new(),
        }
    }
}

/// Identifier for one request-state slot, handed out by `RequestStateKey::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RequestStateKey(u64);

impl RequestStateKey {
    /// Takes the next value from `NEXT_KEY`; successive calls yield distinct keys.
    fn new() -> Self {
        // A relaxed RMW is enough: only uniqueness of the returned id matters, not ordering
        // relative to other memory operations.
        let id = NEXT_KEY.fetch_add(1, Ordering::Relaxed);
        RequestStateKey(id)
    }
}
38
39pub struct RequestState<T> {
41 key: RequestStateKey,
42 init: Arc<dyn Fn() -> T + Send + Sync>,
43}
44
45impl<T> Clone for RequestState<T> {
46 fn clone(&self) -> Self {
47 Self {
48 key: self.key,
49 init: Arc::clone(&self.init),
50 }
51 }
52}
53
54impl<T> RequestState<T>
55where
56 T: Clone + Send + 'static,
57{
58 pub fn get(&self) -> Result<T, RustAuthError> {
60 with_current_store(|store| {
61 if let Some(value) = store.values.get(&self.key) {
62 return value
63 .downcast_ref::<T>()
64 .cloned()
65 .ok_or(RustAuthError::RequestStateTypeMismatch);
66 }
67
68 let value = (self.init)();
69 store.values.insert(self.key, Box::new(value.clone()));
70 Ok(value)
71 })
72 }
73
74 pub fn set(&self, value: T) -> Result<(), RustAuthError> {
76 with_current_store(|store| {
77 store.values.insert(self.key, Box::new(value));
78 Ok(())
79 })
80 }
81
82 pub fn key(&self) -> RequestStateKey {
84 self.key
85 }
86}
87
88pub fn define_request_state<T>(init: impl Fn() -> T + Send + Sync + 'static) -> RequestState<T>
90where
91 T: Clone + Send + 'static,
92{
93 RequestState {
94 key: RequestStateKey::new(),
95 init: Arc::new(init),
96 }
97}
98
99static CURRENT_SESSION_USER: OnceLock<RequestState<Option<Value>>> = OnceLock::new();
100static CURRENT_SESSION: OnceLock<RequestState<Option<CurrentSession>>> = OnceLock::new();
101static CURRENT_NEW_SESSION: OnceLock<RequestState<Option<NewSession>>> = OnceLock::new();
102static CURRENT_REQUEST_PATH: OnceLock<RequestState<Option<String>>> = OnceLock::new();
103static REQUEST_IS_EXTERNAL: OnceLock<RequestState<bool>> = OnceLock::new();
104static SHOULD_SKIP_SESSION_REFRESH: OnceLock<RequestState<bool>> = OnceLock::new();
105
106#[derive(Debug, Clone, PartialEq, Eq)]
107pub struct NewSession {
108 pub session: Session,
109 pub user: User,
110}
111
112#[derive(Debug, Clone, PartialEq, Eq)]
113pub struct CurrentSession {
114 pub session: Session,
115 pub user: User,
116}
117
118fn current_session_user_state() -> &'static RequestState<Option<Value>> {
119 CURRENT_SESSION_USER.get_or_init(|| define_request_state(|| None))
120}
121
122fn current_request_path_state() -> &'static RequestState<Option<String>> {
123 CURRENT_REQUEST_PATH.get_or_init(|| define_request_state(|| None))
124}
125
126pub fn set_current_session_user(user: Value) -> Result<(), RustAuthError> {
128 current_session_user_state().set(Some(user))
129}
130
131pub fn current_session_user() -> Result<Option<Value>, RustAuthError> {
133 current_session_user_state().get()
134}
135
136fn current_session_state() -> &'static RequestState<Option<CurrentSession>> {
137 CURRENT_SESSION.get_or_init(|| define_request_state(|| None))
138}
139
140pub fn set_current_session(session: Session, user: User) -> Result<(), RustAuthError> {
141 current_session_state().set(Some(CurrentSession { session, user }))
142}
143
144pub fn current_session() -> Result<Option<CurrentSession>, RustAuthError> {
145 current_session_state().get()
146}
147
148fn current_new_session_state() -> &'static RequestState<Option<NewSession>> {
149 CURRENT_NEW_SESSION.get_or_init(|| define_request_state(|| None))
150}
151
152pub fn set_current_new_session(session: Session, user: User) -> Result<(), RustAuthError> {
153 current_new_session_state().set(Some(NewSession { session, user }))
154}
155
156pub fn current_new_session() -> Result<Option<NewSession>, RustAuthError> {
157 current_new_session_state().get()
158}
159
160pub fn set_current_request_path(path: impl Into<String>) -> Result<(), RustAuthError> {
162 current_request_path_state().set(Some(path.into()))
163}
164
165pub fn current_request_path() -> Result<Option<String>, RustAuthError> {
167 current_request_path_state().get()
168}
169
170fn request_is_external_state() -> &'static RequestState<bool> {
171 REQUEST_IS_EXTERNAL.get_or_init(|| define_request_state(|| false))
172}
173
174fn should_skip_session_refresh_state() -> &'static RequestState<bool> {
175 SHOULD_SKIP_SESSION_REFRESH.get_or_init(|| define_request_state(|| false))
176}
177
178pub fn set_request_external(external: bool) -> Result<(), RustAuthError> {
181 request_is_external_state().set(external)
182}
183
184pub fn is_external_request() -> bool {
188 if !has_request_state() {
189 return false;
190 }
191 request_is_external_state().get().unwrap_or(false)
192}
193
194pub fn set_should_skip_session_refresh(skip: bool) -> Result<(), RustAuthError> {
196 should_skip_session_refresh_state().set(skip)
197}
198
199pub fn should_skip_session_refresh() -> bool {
201 if !has_request_state() {
202 return false;
203 }
204 should_skip_session_refresh_state().get().unwrap_or(false)
205}
206
207pub async fn run_with_request_state<F>(future: F) -> F::Output
209where
210 F: Future,
211{
212 REQUEST_STATE
213 .scope(RefCell::new(RequestStateStore::new()), future)
214 .await
215}
216
217pub fn has_request_state() -> bool {
219 REQUEST_STATE.try_with(|_| ()).is_ok()
220}
221
222fn with_current_store<T>(
223 operation: impl FnOnce(&mut RequestStateStore) -> Result<T, RustAuthError>,
224) -> Result<T, RustAuthError> {
225 REQUEST_STATE
226 .try_with(|store| operation(&mut store.borrow_mut()))
227 .map_err(|_| RustAuthError::RequestStateMissing)?
228}