Skip to main content

rustauth_core/options/
hooks.rs

1//! Global request hooks (parity with Better Auth `hooks` init option).
2
3use std::fmt;
4use std::sync::Arc;
5
6use http::Method;
7
8use crate::api::{ApiRequest, ApiResponse};
9use crate::context::AuthContext;
10use crate::error::RustAuthError;
11use crate::plugin::{PluginAfterHook, PluginBeforeHook, PluginHookMatcher};
12
13/// Global before/after hooks applied to every matched endpoint.
14#[derive(Clone, Default)]
15pub struct GlobalHooksOptions {
16    pub before: Option<Arc<dyn GlobalBeforeHook>>,
17    pub after: Option<Arc<dyn GlobalAfterHook>>,
18}
19
20impl GlobalHooksOptions {
21    pub fn new() -> Self {
22        Self::default()
23    }
24
25    #[must_use]
26    pub fn before<H>(mut self, hook: H) -> Self
27    where
28        H: GlobalBeforeHook,
29    {
30        self.before = Some(Arc::new(hook));
31        self
32    }
33
34    #[must_use]
35    pub fn after<H>(mut self, hook: H) -> Self
36    where
37        H: GlobalAfterHook,
38    {
39        self.after = Some(Arc::new(hook));
40        self
41    }
42}
43
44impl fmt::Debug for GlobalHooksOptions {
45    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
46        formatter
47            .debug_struct("GlobalHooksOptions")
48            .field(
49                "before",
50                &self.before.as_ref().map(|_| "<global-before-hook>"),
51            )
52            .field("after", &self.after.as_ref().map(|_| "<global-after-hook>"))
53            .finish()
54    }
55}
56
57/// Runs before any endpoint handler (after plugins' onRequest).
58pub trait GlobalBeforeHook: Send + Sync + 'static {
59    fn before(
60        &self,
61        context: &AuthContext,
62        request: ApiRequest,
63        method: &Method,
64        path: &str,
65    ) -> Result<GlobalHookAction, RustAuthError>;
66}
67
68impl<F> GlobalBeforeHook for F
69where
70    F: Fn(&AuthContext, ApiRequest, &Method, &str) -> Result<GlobalHookAction, RustAuthError>
71        + Send
72        + Sync
73        + 'static,
74{
75    fn before(
76        &self,
77        context: &AuthContext,
78        request: ApiRequest,
79        method: &Method,
80        path: &str,
81    ) -> Result<GlobalHookAction, RustAuthError> {
82        self(context, request, method, path)
83    }
84}
85
86/// Runs after any endpoint handler (before plugins' onResponse).
87pub trait GlobalAfterHook: Send + Sync + 'static {
88    fn after(
89        &self,
90        context: &AuthContext,
91        request: &ApiRequest,
92        response: ApiResponse,
93        method: &Method,
94        path: &str,
95    ) -> Result<ApiResponse, RustAuthError>;
96}
97
98impl<F> GlobalAfterHook for F
99where
100    F: Fn(
101            &AuthContext,
102            &ApiRequest,
103            ApiResponse,
104            &Method,
105            &str,
106        ) -> Result<ApiResponse, RustAuthError>
107        + Send
108        + Sync
109        + 'static,
110{
111    fn after(
112        &self,
113        context: &AuthContext,
114        request: &ApiRequest,
115        response: ApiResponse,
116        method: &Method,
117        path: &str,
118    ) -> Result<ApiResponse, RustAuthError> {
119        self(context, request, response, method, path)
120    }
121}
122
123/// Action returned by a global before hook.
124pub enum GlobalHookAction {
125    Continue(ApiRequest),
126    Respond(ApiResponse),
127}
128
129pub(crate) fn plugin_before_hooks(options: &GlobalHooksOptions) -> Vec<PluginBeforeHook> {
130    let Some(hook) = options.before.clone() else {
131        return Vec::new();
132    };
133    vec![PluginBeforeHook {
134        matcher: PluginHookMatcher {
135            path: "/*".to_owned(),
136            method: None,
137            operation_id: None,
138        },
139        handler: Arc::new(move |context, request| {
140            let method = request.method().clone();
141            let path = request
142                .uri()
143                .path()
144                .trim_start_matches(context.base_path.trim_end_matches('/'))
145                .to_owned();
146            match hook.before(context, request, &method, &path)? {
147                GlobalHookAction::Continue(request) => {
148                    Ok(crate::plugin::PluginBeforeHookAction::Continue(request))
149                }
150                GlobalHookAction::Respond(response) => {
151                    Ok(crate::plugin::PluginBeforeHookAction::Respond(response))
152                }
153            }
154        }),
155    }]
156}
157
158pub(crate) fn plugin_after_hooks(options: &GlobalHooksOptions) -> Vec<PluginAfterHook> {
159    let Some(hook) = options.after.clone() else {
160        return Vec::new();
161    };
162    vec![PluginAfterHook {
163        matcher: PluginHookMatcher {
164            path: "/*".to_owned(),
165            method: None,
166            operation_id: None,
167        },
168        handler: Arc::new(move |context, request, response| {
169            let method = request.method().clone();
170            let path = request
171                .uri()
172                .path()
173                .trim_start_matches(context.base_path.trim_end_matches('/'))
174                .to_owned();
175            let response = hook.after(context, request, response, &method, &path)?;
176            Ok(crate::plugin::PluginAfterHookAction::Continue(response))
177        }),
178    }]
179}
180
#[cfg(test)]
mod tests {
    use http::Method;

    use crate::api::{ApiRequest, ApiResponse};
    use crate::context::AuthContext;

    use super::*;

    /// Before hook fixture that forwards the request untouched.
    struct PassThroughBefore;

    /// After hook fixture that forwards the response untouched.
    struct PassThroughAfter;

    impl GlobalBeforeHook for PassThroughBefore {
        fn before(
            &self,
            _context: &AuthContext,
            request: ApiRequest,
            _method: &Method,
            _path: &str,
        ) -> Result<GlobalHookAction, RustAuthError> {
            Ok(GlobalHookAction::Continue(request))
        }
    }

    impl GlobalAfterHook for PassThroughAfter {
        fn after(
            &self,
            _context: &AuthContext,
            _request: &ApiRequest,
            response: ApiResponse,
            _method: &Method,
            _path: &str,
        ) -> Result<ApiResponse, RustAuthError> {
            Ok(response)
        }
    }

    /// Chaining `before` and `after` on a fresh options value fills both slots.
    #[test]
    fn global_hooks_options_supports_fluent_registration() {
        let GlobalHooksOptions { before, after } = GlobalHooksOptions::new()
            .before(PassThroughBefore)
            .after(PassThroughAfter);

        assert!(before.is_some());
        assert!(after.is_some());
    }
}