1use std::fmt;
4use std::sync::Arc;
5
6use http::Method;
7
8use crate::api::{ApiRequest, ApiResponse};
9use crate::context::AuthContext;
10use crate::error::RustAuthError;
11use crate::plugin::{PluginAfterHook, PluginBeforeHook, PluginHookMatcher};
12
13#[derive(Clone, Default)]
15pub struct GlobalHooksOptions {
16 pub before: Option<Arc<dyn GlobalBeforeHook>>,
17 pub after: Option<Arc<dyn GlobalAfterHook>>,
18}
19
20impl GlobalHooksOptions {
21 pub fn new() -> Self {
22 Self::default()
23 }
24
25 #[must_use]
26 pub fn before<H>(mut self, hook: H) -> Self
27 where
28 H: GlobalBeforeHook,
29 {
30 self.before = Some(Arc::new(hook));
31 self
32 }
33
34 #[must_use]
35 pub fn after<H>(mut self, hook: H) -> Self
36 where
37 H: GlobalAfterHook,
38 {
39 self.after = Some(Arc::new(hook));
40 self
41 }
42}
43
44impl fmt::Debug for GlobalHooksOptions {
45 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
46 formatter
47 .debug_struct("GlobalHooksOptions")
48 .field(
49 "before",
50 &self.before.as_ref().map(|_| "<global-before-hook>"),
51 )
52 .field("after", &self.after.as_ref().map(|_| "<global-after-hook>"))
53 .finish()
54 }
55}
56
/// Global hook run on incoming requests.
///
/// A registered implementation is wrapped in a plugin before hook that matches
/// every path (`/*`) and every method. Any `Fn(&AuthContext, ApiRequest, &Method,
/// &str) -> Result<GlobalHookAction, RustAuthError>` that is `Send + Sync + 'static`
/// implements this trait automatically.
pub trait GlobalBeforeHook: Send + Sync + 'static {
    /// Inspects or rewrites `request`, or replaces it with a response.
    ///
    /// `method` is the request's HTTP method and `path` is the request path with
    /// the context's base path removed. Returning [`GlobalHookAction::Continue`]
    /// maps to the plugin `Continue` action carrying the returned request;
    /// [`GlobalHookAction::Respond`] maps to the plugin `Respond` action.
    ///
    /// # Errors
    ///
    /// Any error is returned unchanged from the wrapping plugin hook handler.
    fn before(
        &self,
        context: &AuthContext,
        request: ApiRequest,
        method: &Method,
        path: &str,
    ) -> Result<GlobalHookAction, RustAuthError>;
}
67
68impl<F> GlobalBeforeHook for F
69where
70 F: Fn(&AuthContext, ApiRequest, &Method, &str) -> Result<GlobalHookAction, RustAuthError>
71 + Send
72 + Sync
73 + 'static,
74{
75 fn before(
76 &self,
77 context: &AuthContext,
78 request: ApiRequest,
79 method: &Method,
80 path: &str,
81 ) -> Result<GlobalHookAction, RustAuthError> {
82 self(context, request, method, path)
83 }
84}
85
/// Global hook run on outgoing responses.
///
/// A registered implementation is wrapped in a plugin after hook that matches
/// every path (`/*`) and every method. Any `Fn(&AuthContext, &ApiRequest,
/// ApiResponse, &Method, &str) -> Result<ApiResponse, RustAuthError>` that is
/// `Send + Sync + 'static` implements this trait automatically.
pub trait GlobalAfterHook: Send + Sync + 'static {
    /// Inspects or rewrites `response` produced for `request`.
    ///
    /// `method` is the request's HTTP method and `path` is the request path with
    /// the context's base path removed. The returned response is passed on as
    /// the plugin `Continue` action.
    ///
    /// # Errors
    ///
    /// Any error is returned unchanged from the wrapping plugin hook handler.
    fn after(
        &self,
        context: &AuthContext,
        request: &ApiRequest,
        response: ApiResponse,
        method: &Method,
        path: &str,
    ) -> Result<ApiResponse, RustAuthError>;
}
97
98impl<F> GlobalAfterHook for F
99where
100 F: Fn(
101 &AuthContext,
102 &ApiRequest,
103 ApiResponse,
104 &Method,
105 &str,
106 ) -> Result<ApiResponse, RustAuthError>
107 + Send
108 + Sync
109 + 'static,
110{
111 fn after(
112 &self,
113 context: &AuthContext,
114 request: &ApiRequest,
115 response: ApiResponse,
116 method: &Method,
117 path: &str,
118 ) -> Result<ApiResponse, RustAuthError> {
119 self(context, request, response, method, path)
120 }
121}
122
/// Result of a [`GlobalBeforeHook`]: either keep going or answer right away.
pub enum GlobalHookAction {
    /// Pass this (possibly modified) request on; mapped to the plugin `Continue` action.
    Continue(ApiRequest),
    /// Use this response instead; mapped to the plugin `Respond` action.
    Respond(ApiResponse),
}
128
129pub(crate) fn plugin_before_hooks(options: &GlobalHooksOptions) -> Vec<PluginBeforeHook> {
130 let Some(hook) = options.before.clone() else {
131 return Vec::new();
132 };
133 vec![PluginBeforeHook {
134 matcher: PluginHookMatcher {
135 path: "/*".to_owned(),
136 method: None,
137 operation_id: None,
138 },
139 handler: Arc::new(move |context, request| {
140 let method = request.method().clone();
141 let path = request
142 .uri()
143 .path()
144 .trim_start_matches(context.base_path.trim_end_matches('/'))
145 .to_owned();
146 match hook.before(context, request, &method, &path)? {
147 GlobalHookAction::Continue(request) => {
148 Ok(crate::plugin::PluginBeforeHookAction::Continue(request))
149 }
150 GlobalHookAction::Respond(response) => {
151 Ok(crate::plugin::PluginBeforeHookAction::Respond(response))
152 }
153 }
154 }),
155 }]
156}
157
158pub(crate) fn plugin_after_hooks(options: &GlobalHooksOptions) -> Vec<PluginAfterHook> {
159 let Some(hook) = options.after.clone() else {
160 return Vec::new();
161 };
162 vec![PluginAfterHook {
163 matcher: PluginHookMatcher {
164 path: "/*".to_owned(),
165 method: None,
166 operation_id: None,
167 },
168 handler: Arc::new(move |context, request, response| {
169 let method = request.method().clone();
170 let path = request
171 .uri()
172 .path()
173 .trim_start_matches(context.base_path.trim_end_matches('/'))
174 .to_owned();
175 let response = hook.after(context, request, response, &method, &path)?;
176 Ok(crate::plugin::PluginAfterHookAction::Continue(response))
177 }),
178 }]
179}
180
#[cfg(test)]
mod tests {
    use http::Method;

    use crate::api::{ApiRequest, ApiResponse};
    use crate::context::AuthContext;

    use super::*;

    /// Before hook that always lets the request through unchanged.
    struct TestBeforeHook;
    /// After hook that always returns the response unchanged.
    struct TestAfterHook;

    impl GlobalBeforeHook for TestBeforeHook {
        fn before(
            &self,
            _context: &AuthContext,
            request: ApiRequest,
            _method: &Method,
            _path: &str,
        ) -> Result<GlobalHookAction, RustAuthError> {
            Ok(GlobalHookAction::Continue(request))
        }
    }

    impl GlobalAfterHook for TestAfterHook {
        fn after(
            &self,
            _context: &AuthContext,
            _request: &ApiRequest,
            response: ApiResponse,
            _method: &Method,
            _path: &str,
        ) -> Result<ApiResponse, RustAuthError> {
            Ok(response)
        }
    }

    #[test]
    fn global_hooks_options_supports_fluent_registration() {
        let options = GlobalHooksOptions::new()
            .before(TestBeforeHook)
            .after(TestAfterHook);

        assert!(options.before.is_some());
        assert!(options.after.is_some());
    }

    #[test]
    fn global_hooks_options_accepts_closures() {
        // Exercises the blanket `Fn` implementations of both hook traits.
        let options = GlobalHooksOptions::new()
            .before(
                |_context: &AuthContext,
                 request: ApiRequest,
                 _method: &Method,
                 _path: &str|
                 -> Result<GlobalHookAction, RustAuthError> {
                    Ok(GlobalHookAction::Continue(request))
                },
            )
            .after(
                |_context: &AuthContext,
                 _request: &ApiRequest,
                 response: ApiResponse,
                 _method: &Method,
                 _path: &str|
                 -> Result<ApiResponse, RustAuthError> { Ok(response) },
            );

        assert!(options.before.is_some());
        assert!(options.after.is_some());
    }

    #[test]
    fn empty_options_register_no_plugin_hooks() {
        let options = GlobalHooksOptions::new();

        assert!(plugin_before_hooks(&options).is_empty());
        assert!(plugin_after_hooks(&options).is_empty());
    }

    #[test]
    fn registered_hooks_become_catch_all_plugin_hooks() {
        let options = GlobalHooksOptions::new()
            .before(TestBeforeHook)
            .after(TestAfterHook);

        let before = plugin_before_hooks(&options);
        let after = plugin_after_hooks(&options);

        assert_eq!(before.len(), 1);
        assert_eq!(after.len(), 1);
        assert_eq!(before[0].matcher.path, "/*");
        assert!(before[0].matcher.method.is_none());
        assert_eq!(after[0].matcher.path, "/*");
        assert!(after[0].matcher.method.is_none());
    }
}