Skip to main content

rustauth_core/plugin/
hooks.rs

1//! Endpoint-scoped plugin hooks.
2
3use crate::api::{ApiRequest, ApiResponse};
4use crate::context::AuthContext;
5use crate::error::RustAuthError;
6use http::Method;
7use std::fmt;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11
12pub type PluginBeforeHookHandler = Arc<
13    dyn Fn(&AuthContext, ApiRequest) -> Result<PluginBeforeHookAction, RustAuthError> + Send + Sync,
14>;
15pub type PluginAfterHookHandler = Arc<
16    dyn Fn(&AuthContext, &ApiRequest, ApiResponse) -> Result<PluginAfterHookAction, RustAuthError>
17        + Send
18        + Sync,
19>;
20pub type PluginBeforeHookFuture<'a> =
21    Pin<Box<dyn Future<Output = Result<PluginBeforeHookAction, RustAuthError>> + Send + 'a>>;
22pub type PluginAfterHookFuture<'a> =
23    Pin<Box<dyn Future<Output = Result<PluginAfterHookAction, RustAuthError>> + Send + 'a>>;
24pub type PluginAsyncBeforeHookHandler =
25    Arc<dyn for<'a> Fn(&'a AuthContext, ApiRequest) -> PluginBeforeHookFuture<'a> + Send + Sync>;
26pub type PluginAsyncAfterHookHandler = Arc<
27    dyn for<'a> Fn(&'a AuthContext, &'a ApiRequest, ApiResponse) -> PluginAfterHookFuture<'a>
28        + Send
29        + Sync,
30>;
31
32/// Wraps an async after-hook handler so plugin authors do not need `Box::pin`.
33pub fn async_after_hook_handler<F, Fut>(
34    handler: F,
35) -> impl for<'a> Fn(&'a AuthContext, &'a ApiRequest, ApiResponse) -> PluginAfterHookFuture<'a>
36       + Send
37       + Sync
38       + Clone
39       + 'static
40where
41    F: for<'a> Fn(AuthContext, &'a ApiRequest, ApiResponse) -> Fut + Send + Sync + Clone + 'static,
42    for<'a> Fut: Future<Output = Result<PluginAfterHookAction, RustAuthError>> + Send + 'a,
43{
44    move |context: &AuthContext, request: &ApiRequest, response: ApiResponse| {
45        let handler = handler.clone();
46        let context = context.clone();
47        Box::pin(handler(context, request, response))
48    }
49}
50
51/// Wraps an async before-hook handler so plugin authors do not need `Box::pin`.
52pub fn async_before_hook_handler<F, Fut>(
53    handler: F,
54) -> impl for<'a> Fn(&'a AuthContext, ApiRequest) -> PluginBeforeHookFuture<'a>
55       + Send
56       + Sync
57       + Clone
58       + 'static
59where
60    F: Fn(AuthContext, ApiRequest) -> Fut + Send + Sync + Clone + 'static,
61    Fut: Future<Output = Result<PluginBeforeHookAction, RustAuthError>> + Send + 'static,
62{
63    move |context: &AuthContext, request: ApiRequest| {
64        let handler = handler.clone();
65        let context = context.clone();
66        Box::pin(handler(context, request))
67    }
68}
69
70/// Action returned by a before endpoint hook.
71pub enum PluginBeforeHookAction {
72    Continue(ApiRequest),
73    Respond(ApiResponse),
74}
75
76/// Action returned by an after endpoint hook.
77pub enum PluginAfterHookAction {
78    Continue(ApiResponse),
79}
80
81/// Matcher used to select endpoint hooks.
82#[derive(Debug, Clone, PartialEq, Eq)]
83pub struct PluginHookMatcher {
84    pub path: String,
85    pub method: Option<Method>,
86    pub operation_id: Option<String>,
87}
88
89impl PluginHookMatcher {
90    pub fn path(path: impl Into<String>) -> Self {
91        Self {
92            path: path.into(),
93            method: None,
94            operation_id: None,
95        }
96    }
97
98    #[must_use]
99    pub fn method(mut self, method: Method) -> Self {
100        self.method = Some(method);
101        self
102    }
103
104    #[must_use]
105    pub fn operation_id(mut self, operation_id: impl Into<String>) -> Self {
106        self.operation_id = Some(operation_id.into());
107        self
108    }
109
110    pub fn matches(&self, method: &Method, path: &str, operation_id: Option<&str>) -> bool {
111        if self
112            .method
113            .as_ref()
114            .is_some_and(|expected| expected != method)
115        {
116            return false;
117        }
118        if self
119            .operation_id
120            .as_deref()
121            .is_some_and(|expected| Some(expected) != operation_id)
122        {
123            return false;
124        }
125        path_matches(&self.path, path)
126    }
127}
128
129#[derive(Clone)]
130pub struct PluginBeforeHook {
131    pub matcher: PluginHookMatcher,
132    pub handler: PluginBeforeHookHandler,
133}
134
135impl fmt::Debug for PluginBeforeHook {
136    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
137        formatter
138            .debug_struct("PluginBeforeHook")
139            .field("matcher", &self.matcher)
140            .field("handler", &"<before-hook>")
141            .finish()
142    }
143}
144
145#[derive(Clone)]
146pub struct PluginAfterHook {
147    pub matcher: PluginHookMatcher,
148    pub handler: PluginAfterHookHandler,
149}
150
151impl fmt::Debug for PluginAfterHook {
152    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
153        formatter
154            .debug_struct("PluginAfterHook")
155            .field("matcher", &self.matcher)
156            .field("handler", &"<after-hook>")
157            .finish()
158    }
159}
160
161#[derive(Clone)]
162pub struct PluginAsyncBeforeHook {
163    pub matcher: PluginHookMatcher,
164    pub handler: PluginAsyncBeforeHookHandler,
165}
166
167impl fmt::Debug for PluginAsyncBeforeHook {
168    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
169        formatter
170            .debug_struct("PluginAsyncBeforeHook")
171            .field("matcher", &self.matcher)
172            .field("handler", &"<async-before-hook>")
173            .finish()
174    }
175}
176
177#[derive(Clone)]
178pub struct PluginAsyncAfterHook {
179    pub matcher: PluginHookMatcher,
180    pub handler: PluginAsyncAfterHookHandler,
181}
182
183impl fmt::Debug for PluginAsyncAfterHook {
184    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
185        formatter
186            .debug_struct("PluginAsyncAfterHook")
187            .field("matcher", &self.matcher)
188            .field("handler", &"<async-after-hook>")
189            .finish()
190    }
191}
192
193#[derive(Debug, Clone, Default)]
194pub struct PluginEndpointHooks {
195    pub before: Vec<PluginBeforeHook>,
196    pub after: Vec<PluginAfterHook>,
197    pub async_before: Vec<PluginAsyncBeforeHook>,
198    pub async_after: Vec<PluginAsyncAfterHook>,
199}
200
/// Returns `true` when `path` matches the hook path `pattern`.
///
/// Two pattern syntaxes are supported:
///
/// * **Wildcard**: used when the pattern contains `*`. Each `*` matches any,
///   possibly empty, run of characters, including `/`. Everything else must
///   match literally. For example, `/api/*/items/*` matches `/api/v1/items/42`.
/// * **Segments**: used otherwise. Pattern and path are compared segment by
///   segment after trimming leading and trailing `/`.
///   - A pattern segment starting with `:` (e.g. `:id`) matches any single segment.
///   - Every other segment must be equal.
///   - Both sides must have the same number of segments.
fn path_matches(pattern: &str, path: &str) -> bool {
    let mut pieces = pattern.split('*');
    // `split` always yields at least one piece: the text before the first `*`.
    let head = pieces.next().unwrap_or_default();
    // A last piece exists only if the pattern contains at least one `*`.
    let Some(tail) = pieces.next_back() else {
        return segments_match(pattern, path);
    };
    // Anchor both ends first. Stripping the suffix from what remains after the
    // prefix guarantees the two literals cannot overlap on a short path.
    let Some(rest) = path.strip_prefix(head) else {
        return false;
    };
    let Some(mut rest) = rest.strip_suffix(tail) else {
        return false;
    };
    // Inner literals must appear in order. Taking the leftmost occurrence of
    // each leaves the most room for those that follow, so greedy search is exact.
    pieces.all(|piece| match rest.find(piece) {
        Some(index) => {
            rest = &rest[index + piece.len()..];
            true
        }
        None => false,
    })
}

/// Compares `/`-separated segments; a `:name` pattern segment matches any segment.
fn segments_match(pattern: &str, path: &str) -> bool {
    let mut expected = pattern.trim_matches('/').split('/');
    let mut actual = path.trim_matches('/').split('/');
    loop {
        match (expected.next(), actual.next()) {
            // Both exhausted together: equal segment counts and every segment matched.
            (None, None) => return true,
            (Some(want), Some(got)) if want.starts_with(':') || want == got => {}
            // A mismatched segment, or one side has more segments than the other.
            _ => return false,
        }
    }
}