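//! `vox_types/server_middleware.rs`: per-request [`Extensions`], the [`ServerMiddleware`]
//! hook trait, and helpers for observing whether a handler replied.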

use std::{
    any::{Any, TypeId},
    collections::HashMap,
    future::Future,
    pin::Pin,
    sync::{Arc, Mutex, MutexGuard, PoisonError},
};

use crate::{ReplySink, RequestContext, RequestResponse};
10
/// Type-erased storage behind [`Extensions`]: at most one boxed value per concrete type,
/// keyed by that type's [`TypeId`].
type ExtensionMap = HashMap<TypeId, Box<dyn Any + Send + Sync>>;

/// Per-request type-indexed storage shared across middleware hooks and handlers.
///
/// At most one value of each concrete type `T` is stored. Every accessor takes `&self`;
/// an internal [`Mutex`] serializes access.
///
/// The closures passed to [`Extensions::with`] and [`Extensions::with_mut`] (and `T::clone`
/// in [`Extensions::get_cloned`]) run while that mutex is held. They must not call back into
/// the same `Extensions`: `std::sync::Mutex` is not reentrant, so doing so deadlocks or
/// panics. If such a closure panics, the bag stays usable for later calls. A panic inside
/// `with_mut` may still leave that one value partially updated.
#[derive(Debug, Default)]
pub struct Extensions {
    inner: Mutex<ExtensionMap>,
}

impl Extensions {
    /// Create a new empty extensions bag.
    pub fn new() -> Self {
        Self::default()
    }

    /// Lock the map, recovering the guard if an earlier holder panicked.
    ///
    /// The lock is only held across foreign code inside `with`/`with_mut`. None of the map
    /// operations here can leave a key that disagrees with its boxed value. Recovering is
    /// therefore sound, and it avoids turning one panicking closure into a panic on every
    /// later access to this request's extensions.
    fn lock_map(&self) -> MutexGuard<'_, ExtensionMap> {
        self.inner.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Insert a typed value into the bag, returning the previous value of the same type.
    ///
    /// # Panics
    ///
    /// Panics only if the internal invariant (key == `TypeId` of the value) is violated.
    pub fn insert<T>(&self, value: T) -> Option<T>
    where
        T: Send + Sync + 'static,
    {
        // The guard is a temporary of this statement, so the lock is released before the
        // downcast below.
        let previous = self.lock_map().insert(TypeId::of::<T>(), Box::new(value))?;
        let previous = previous
            .downcast::<T>()
            .expect("extensions type id and boxed value disagreed");
        Some(*previous)
    }

    /// Returns `true` if a value of type `T` is present.
    pub fn contains<T>(&self) -> bool
    where
        T: Send + Sync + 'static,
    {
        self.lock_map().contains_key(&TypeId::of::<T>())
    }

    /// Borrow a typed value from the bag for the duration of `f`.
    ///
    /// Returns `None` (without calling `f`) when no value of type `T` is stored.
    /// `f` runs with the bag locked and must not access this `Extensions` again.
    pub fn with<T, R>(&self, f: impl FnOnce(&T) -> R) -> Option<R>
    where
        T: Send + Sync + 'static,
    {
        let guard = self.lock_map();
        let value = guard
            .get(&TypeId::of::<T>())?
            .downcast_ref::<T>()
            .expect("extensions type id and boxed value disagreed");
        Some(f(value))
    }

    /// Mutably borrow a typed value from the bag for the duration of `f`.
    ///
    /// Returns `None` (without calling `f`) when no value of type `T` is stored.
    /// `f` runs with the bag locked and must not access this `Extensions` again.
    pub fn with_mut<T, R>(&self, f: impl FnOnce(&mut T) -> R) -> Option<R>
    where
        T: Send + Sync + 'static,
    {
        let mut guard = self.lock_map();
        let value = guard
            .get_mut(&TypeId::of::<T>())?
            .downcast_mut::<T>()
            .expect("extensions type id and boxed value disagreed");
        Some(f(value))
    }

    /// Clone a typed value from the bag, or `None` if no value of type `T` is stored.
    pub fn get_cloned<T>(&self) -> Option<T>
    where
        T: Clone + Send + Sync + 'static,
    {
        self.with(|value: &T| value.clone())
    }
}
87
/// Boxed, type-erased future returned by the [`ServerMiddleware`] hooks.
///
/// On native targets the future must be `Send` so it can cross threads.
#[cfg(not(target_arch = "wasm32"))]
pub type BoxMiddlewareFuture<'a> = Pin<Box<dyn Send + Future<Output = ()> + 'a>>;
/// Boxed, type-erased future returned by the [`ServerMiddleware`] hooks.
///
/// On `wasm32` the `Send` bound is omitted.
#[cfg(target_arch = "wasm32")]
pub type BoxMiddlewareFuture<'a> = Pin<Box<dyn Future<Output = ()> + 'a>>;
92
/// Outcome observed by server middleware after handler dispatch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerCallOutcome {
    /// The handler sent a reply through the reply sink.
    Replied,
    /// The handler returned without replying; the runtime will synthesize cancellation.
    DroppedWithoutReply,
}

impl ServerCallOutcome {
    /// Returns `true` for [`ServerCallOutcome::Replied`] and `false` for
    /// [`ServerCallOutcome::DroppedWithoutReply`].
    pub fn replied(self) -> bool {
        match self {
            Self::Replied => true,
            Self::DroppedWithoutReply => false,
        }
    }
}
107
108/// Observe inbound server requests before and after dispatch.
109pub trait ServerMiddleware: Send + Sync + 'static {
110    fn pre<'a>(&'a self, _context: &'a RequestContext<'a>) -> BoxMiddlewareFuture<'a> {
111        Box::pin(async {})
112    }
113
114    fn post<'a>(
115        &'a self,
116        _context: &'a RequestContext<'a>,
117        _outcome: ServerCallOutcome,
118    ) -> BoxMiddlewareFuture<'a> {
119        Box::pin(async {})
120    }
121}
122
123#[derive(Clone)]
124#[doc(hidden)]
125pub struct ServerCallOutcomeHandle {
126    outcome: Arc<Mutex<ServerCallOutcome>>,
127}
128
129impl ServerCallOutcomeHandle {
130    pub fn outcome(&self) -> ServerCallOutcome {
131        *self
132            .outcome
133            .lock()
134            .expect("server call outcome mutex poisoned")
135    }
136}
137
138#[doc(hidden)]
139pub struct ObservedReplySink<R> {
140    inner: Option<R>,
141    outcome: ServerCallOutcomeHandle,
142}
143
144#[doc(hidden)]
145pub fn observe_reply<R>(reply: R) -> (ObservedReplySink<R>, ServerCallOutcomeHandle) {
146    let outcome = ServerCallOutcomeHandle {
147        outcome: Arc::new(Mutex::new(ServerCallOutcome::DroppedWithoutReply)),
148    };
149    (
150        ObservedReplySink {
151            inner: Some(reply),
152            outcome: outcome.clone(),
153        },
154        outcome,
155    )
156}
157
158impl<R> ReplySink for ObservedReplySink<R>
159where
160    R: ReplySink,
161{
162    async fn send_reply(mut self, response: RequestResponse<'_>) {
163        *self
164            .outcome
165            .outcome
166            .lock()
167            .expect("server call outcome mutex poisoned") = ServerCallOutcome::Replied;
168        let reply = self
169            .inner
170            .take()
171            .expect("observed reply sink can only reply once");
172        reply.send_reply(response).await;
173    }
174
175    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
176        self.inner.as_ref().and_then(|reply| reply.channel_binder())
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::{Extensions, ServerCallOutcome};

    #[test]
    fn extensions_store_values_by_type() {
        let extensions = Extensions::new();
        assert!(!extensions.contains::<u32>());
        assert_eq!(extensions.insert(41_u32), None);
        assert!(extensions.contains::<u32>());
        assert_eq!(extensions.get_cloned::<u32>(), Some(41));
        let updated = extensions.with_mut::<u32, _>(|value| {
            *value += 1;
            *value
        });
        assert_eq!(updated, Some(42));
        assert_eq!(extensions.get_cloned::<u32>(), Some(42));
    }

    #[test]
    fn extensions_insert_returns_previous_value_of_same_type() {
        let extensions = Extensions::new();
        assert_eq!(extensions.insert(1_u32), None);
        assert_eq!(extensions.insert(2_u32), Some(1));
        // A different type gets its own slot and leaves the `u32` entry untouched.
        assert_eq!(extensions.insert(String::from("x")), None);
        assert_eq!(extensions.get_cloned::<u32>(), Some(2));
        assert_eq!(extensions.get_cloned::<String>().as_deref(), Some("x"));
    }

    #[test]
    fn extensions_missing_type_yields_none() {
        let extensions = Extensions::new();
        extensions.insert(7_u64);
        // `u64` is present but must not be visible as `u32`.
        assert!(!extensions.contains::<u32>());
        assert_eq!(extensions.with::<u32, _>(|value| *value), None);
        assert_eq!(extensions.with_mut::<u32, _>(|value| *value), None);
        assert_eq!(extensions.get_cloned::<u32>(), None);
        assert_eq!(extensions.get_cloned::<u64>(), Some(7));
    }

    #[test]
    fn server_call_outcome_reports_reply_state() {
        assert!(ServerCallOutcome::Replied.replied());
        assert!(!ServerCallOutcome::DroppedWithoutReply.replied());
    }
}