rustauth_core/plugin/
hooks.rs1use crate::api::{ApiRequest, ApiResponse};
4use crate::context::AuthContext;
5use crate::error::RustAuthError;
6use http::Method;
7use std::fmt;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11
12pub type PluginBeforeHookHandler = Arc<
13 dyn Fn(&AuthContext, ApiRequest) -> Result<PluginBeforeHookAction, RustAuthError> + Send + Sync,
14>;
15pub type PluginAfterHookHandler = Arc<
16 dyn Fn(&AuthContext, &ApiRequest, ApiResponse) -> Result<PluginAfterHookAction, RustAuthError>
17 + Send
18 + Sync,
19>;
20pub type PluginBeforeHookFuture<'a> =
21 Pin<Box<dyn Future<Output = Result<PluginBeforeHookAction, RustAuthError>> + Send + 'a>>;
22pub type PluginAfterHookFuture<'a> =
23 Pin<Box<dyn Future<Output = Result<PluginAfterHookAction, RustAuthError>> + Send + 'a>>;
24pub type PluginAsyncBeforeHookHandler =
25 Arc<dyn for<'a> Fn(&'a AuthContext, ApiRequest) -> PluginBeforeHookFuture<'a> + Send + Sync>;
26pub type PluginAsyncAfterHookHandler = Arc<
27 dyn for<'a> Fn(&'a AuthContext, &'a ApiRequest, ApiResponse) -> PluginAfterHookFuture<'a>
28 + Send
29 + Sync,
30>;
31
32pub fn async_after_hook_handler<F, Fut>(
34 handler: F,
35) -> impl for<'a> Fn(&'a AuthContext, &'a ApiRequest, ApiResponse) -> PluginAfterHookFuture<'a>
36 + Send
37 + Sync
38 + Clone
39 + 'static
40where
41 F: for<'a> Fn(AuthContext, &'a ApiRequest, ApiResponse) -> Fut + Send + Sync + Clone + 'static,
42 for<'a> Fut: Future<Output = Result<PluginAfterHookAction, RustAuthError>> + Send + 'a,
43{
44 move |context: &AuthContext, request: &ApiRequest, response: ApiResponse| {
45 let handler = handler.clone();
46 let context = context.clone();
47 Box::pin(handler(context, request, response))
48 }
49}
50
51pub fn async_before_hook_handler<F, Fut>(
53 handler: F,
54) -> impl for<'a> Fn(&'a AuthContext, ApiRequest) -> PluginBeforeHookFuture<'a>
55 + Send
56 + Sync
57 + Clone
58 + 'static
59where
60 F: Fn(AuthContext, ApiRequest) -> Fut + Send + Sync + Clone + 'static,
61 Fut: Future<Output = Result<PluginBeforeHookAction, RustAuthError>> + Send + 'static,
62{
63 move |context: &AuthContext, request: ApiRequest| {
64 let handler = handler.clone();
65 let context = context.clone();
66 Box::pin(handler(context, request))
67 }
68}
69
/// Result of a "before" hook handler.
pub enum PluginBeforeHookAction {
    /// Continue with the contained request.
    ///
    /// The handler receives the request by value, so this may be the original or a modified
    /// request.
    Continue(ApiRequest),
    /// Provide this response instead of continuing with a request.
    ///
    /// NOTE(review): the short-circuit behavior is implemented by the code that runs the
    /// hooks, not in this module — confirm there.
    Respond(ApiResponse),
}
75
/// Result of an "after" hook handler.
///
/// Currently the only option is to continue with a response.
pub enum PluginAfterHookAction {
    /// Continue with the contained response.
    ///
    /// The handler receives the response by value, so this may be the original or a modified
    /// response.
    Continue(ApiResponse),
}
80
81#[derive(Debug, Clone, PartialEq, Eq)]
83pub struct PluginHookMatcher {
84 pub path: String,
85 pub method: Option<Method>,
86 pub operation_id: Option<String>,
87}
88
89impl PluginHookMatcher {
90 pub fn path(path: impl Into<String>) -> Self {
91 Self {
92 path: path.into(),
93 method: None,
94 operation_id: None,
95 }
96 }
97
98 #[must_use]
99 pub fn method(mut self, method: Method) -> Self {
100 self.method = Some(method);
101 self
102 }
103
104 #[must_use]
105 pub fn operation_id(mut self, operation_id: impl Into<String>) -> Self {
106 self.operation_id = Some(operation_id.into());
107 self
108 }
109
110 pub fn matches(&self, method: &Method, path: &str, operation_id: Option<&str>) -> bool {
111 if self
112 .method
113 .as_ref()
114 .is_some_and(|expected| expected != method)
115 {
116 return false;
117 }
118 if self
119 .operation_id
120 .as_deref()
121 .is_some_and(|expected| Some(expected) != operation_id)
122 {
123 return false;
124 }
125 path_matches(&self.path, path)
126 }
127}
128
129#[derive(Clone)]
130pub struct PluginBeforeHook {
131 pub matcher: PluginHookMatcher,
132 pub handler: PluginBeforeHookHandler,
133}
134
135impl fmt::Debug for PluginBeforeHook {
136 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
137 formatter
138 .debug_struct("PluginBeforeHook")
139 .field("matcher", &self.matcher)
140 .field("handler", &"<before-hook>")
141 .finish()
142 }
143}
144
145#[derive(Clone)]
146pub struct PluginAfterHook {
147 pub matcher: PluginHookMatcher,
148 pub handler: PluginAfterHookHandler,
149}
150
151impl fmt::Debug for PluginAfterHook {
152 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
153 formatter
154 .debug_struct("PluginAfterHook")
155 .field("matcher", &self.matcher)
156 .field("handler", &"<after-hook>")
157 .finish()
158 }
159}
160
161#[derive(Clone)]
162pub struct PluginAsyncBeforeHook {
163 pub matcher: PluginHookMatcher,
164 pub handler: PluginAsyncBeforeHookHandler,
165}
166
167impl fmt::Debug for PluginAsyncBeforeHook {
168 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
169 formatter
170 .debug_struct("PluginAsyncBeforeHook")
171 .field("matcher", &self.matcher)
172 .field("handler", &"<async-before-hook>")
173 .finish()
174 }
175}
176
177#[derive(Clone)]
178pub struct PluginAsyncAfterHook {
179 pub matcher: PluginHookMatcher,
180 pub handler: PluginAsyncAfterHookHandler,
181}
182
183impl fmt::Debug for PluginAsyncAfterHook {
184 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
185 formatter
186 .debug_struct("PluginAsyncAfterHook")
187 .field("matcher", &self.matcher)
188 .field("handler", &"<async-after-hook>")
189 .finish()
190 }
191}
192
/// All endpoint hooks registered by a plugin, grouped by phase and by sync/async kind.
///
/// `Default` yields a value with every list empty.
#[derive(Debug, Clone, Default)]
pub struct PluginEndpointHooks {
    /// Synchronous hooks whose handlers receive the request by value.
    pub before: Vec<PluginBeforeHook>,
    /// Synchronous hooks whose handlers receive the request by reference and the response by
    /// value.
    pub after: Vec<PluginAfterHook>,
    /// Asynchronous counterparts of `before`.
    pub async_before: Vec<PluginAsyncBeforeHook>,
    /// Asynchronous counterparts of `after`.
    pub async_after: Vec<PluginAsyncAfterHook>,
}
200
/// Returns whether the request `path` matches the hook path `pattern`.
///
/// Two pattern forms are supported:
///
/// * **Wildcard.** If `pattern` contains `*`, each `*` matches any (possibly empty) run of
///   characters, including `/`. All other characters must match literally. For example,
///   `/api/*` matches `/api/users/42` and `/a/*/c/*` matches `/a/b/c/d`.
/// * **Segments.** Otherwise, leading and trailing `/` are ignored on both sides, and pattern
///   and path must have the same number of `/`-separated segments. A pattern segment starting
///   with `:` (e.g. `:id`) matches any single non-empty segment. Every other segment must be
///   equal.
fn path_matches(pattern: &str, path: &str) -> bool {
    if pattern.contains('*') {
        return wildcard_matches(pattern, path);
    }
    let mut expected_segments = pattern.trim_matches('/').split('/');
    let mut actual_segments = path.trim_matches('/').split('/');
    // Walk both segment lists in lockstep without allocating; they must run out together.
    loop {
        match (expected_segments.next(), actual_segments.next()) {
            (None, None) => return true,
            (Some(expected), Some(actual)) if segment_matches(expected, actual) => {}
            _ => return false,
        }
    }
}

/// Matches one pattern segment against one path segment.
fn segment_matches(expected: &str, actual: &str) -> bool {
    if expected.starts_with(':') {
        // A placeholder must capture something: `/a/:x/b` must not match `/a//b`.
        !actual.is_empty()
    } else {
        expected == actual
    }
}

/// Glob-style match in which every `*` in `pattern` matches any (possibly empty) substring.
fn wildcard_matches(pattern: &str, path: &str) -> bool {
    let mut pieces = pattern.split('*');
    // The text before the first `*` is anchored to the start of the path.
    let prefix = pieces.next().unwrap_or("");
    let Some(mut rest) = path.strip_prefix(prefix) else {
        return false;
    };
    // The text after the last `*` is anchored to the end of what remains. Checking it against
    // `rest`, not the whole path, stops it overlapping the prefix (`/a*a` must not match `/a`).
    let Some(suffix) = pieces.next_back() else {
        // No `*` at all: the pattern is a plain literal.
        return rest.is_empty();
    };
    // Literal pieces between `*`s must appear in order. Taking the leftmost occurrence of each
    // leaves the most text available for the pieces that follow.
    for piece in pieces {
        match rest.find(piece) {
            Some(start) => rest = &rest[start + piece.len()..],
            None => return false,
        }
    }
    rest.ends_with(suffix)
}