vox_types/
server_middleware.rs1use std::{
2 any::{Any, TypeId},
3 collections::HashMap,
4 future::Future,
5 pin::Pin,
6 sync::{Arc, Mutex},
7};
8
9use crate::{ReplySink, RequestContext, RequestResponse};
10
/// Map from a value's type to the single value of that type stored here.
///
/// Entries are keyed by [`TypeId`], so at most one value of any given type `T`
/// is held at a time. The map lives behind a [`Mutex`], which is why every
/// method takes `&self`.
#[derive(Debug, Default)]
pub struct Extensions {
    inner: Mutex<ExtensionMap>,
}

/// Backing storage for [`Extensions`]: each key is the `TypeId` of the boxed
/// value stored under it.
type ExtensionMap = HashMap<TypeId, Box<dyn Any + Send + Sync>>;

impl Extensions {
    /// Creates an empty `Extensions`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `value`, replacing any existing value of the same type `T`.
    ///
    /// Returns the value that was replaced, or `None` if no `T` was stored.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn insert<T>(&self, value: T) -> Option<T>
    where
        T: Send + Sync + 'static,
    {
        // The guard is a temporary of this statement, so the lock is released
        // before the previous value is unboxed.
        let previous = self.lock_map().insert(TypeId::of::<T>(), Box::new(value))?;
        Some(Self::unbox::<T>(previous))
    }

    /// Removes and returns the stored value of type `T`, if there is one.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn remove<T>(&self) -> Option<T>
    where
        T: Send + Sync + 'static,
    {
        let removed = self.lock_map().remove(&TypeId::of::<T>())?;
        Some(Self::unbox::<T>(removed))
    }

    /// Returns `true` if a value of type `T` is stored.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn contains<T>(&self) -> bool
    where
        T: Send + Sync + 'static,
    {
        self.lock_map().contains_key(&TypeId::of::<T>())
    }

    /// Calls `f` with a shared reference to the stored `T` and returns its result.
    ///
    /// Returns `None` without calling `f` if no `T` is stored.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned. The lock is held while `f`
    /// runs, so calling any method of the same `Extensions` from inside `f`
    /// re-locks a non-reentrant [`Mutex`] and will panic or deadlock. If `f`
    /// panics, the mutex is poisoned and every later call panics.
    pub fn with<T, R>(&self, f: impl FnOnce(&T) -> R) -> Option<R>
    where
        T: Send + Sync + 'static,
    {
        let guard = self.lock_map();
        let value = guard
            .get(&TypeId::of::<T>())?
            .downcast_ref::<T>()
            .expect("extensions type id and boxed value disagreed");
        Some(f(value))
    }

    /// Calls `f` with a mutable reference to the stored `T` and returns its result.
    ///
    /// Returns `None` without calling `f` if no `T` is stored.
    ///
    /// # Panics
    ///
    /// The same lock-holding caveats as [`Extensions::with`] apply: re-entering
    /// this `Extensions` from `f` panics or deadlocks, and a panic in `f`
    /// poisons the mutex.
    pub fn with_mut<T, R>(&self, f: impl FnOnce(&mut T) -> R) -> Option<R>
    where
        T: Send + Sync + 'static,
    {
        let mut guard = self.lock_map();
        let value = guard
            .get_mut(&TypeId::of::<T>())?
            .downcast_mut::<T>()
            .expect("extensions type id and boxed value disagreed");
        Some(f(value))
    }

    /// Returns a clone of the stored `T`, or `None` if no `T` is stored.
    pub fn get_cloned<T>(&self) -> Option<T>
    where
        T: Clone + Send + Sync + 'static,
    {
        self.with(|stored: &T| stored.clone())
    }

    /// Locks the backing map.
    ///
    /// # Panics
    ///
    /// Panics if a previous holder of the lock panicked (mutex poisoning).
    fn lock_map(&self) -> std::sync::MutexGuard<'_, ExtensionMap> {
        self.inner.lock().expect("extensions mutex poisoned")
    }

    /// Moves a `T` out of a boxed map entry.
    ///
    /// # Panics
    ///
    /// Panics if `boxed` does not hold a `T`. Callers only pass entries stored
    /// under `TypeId::of::<T>()`, so a panic here indicates a bug.
    fn unbox<T: 'static>(boxed: Box<dyn Any + Send + Sync>) -> T {
        *boxed
            .downcast::<T>()
            .expect("extensions type id and boxed value disagreed")
    }
}
87
/// Boxed, type-erased future returned by the [`ServerMiddleware`] hooks.
///
/// On `wasm32` the future is not required to be `Send`.
#[cfg(target_arch = "wasm32")]
pub type BoxMiddlewareFuture<'a> = Pin<Box<dyn Future<Output = ()> + 'a>>;

/// Boxed, type-erased future returned by the [`ServerMiddleware`] hooks.
///
/// On every target other than `wasm32` the future must be `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub type BoxMiddlewareFuture<'a> = Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
92
/// Whether a server call's reply sink was used.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerCallOutcome {
    /// `send_reply` was called on the observed reply sink.
    Replied,
    /// `send_reply` has not been called on the observed reply sink (for
    /// example, the sink was dropped without replying).
    DroppedWithoutReply,
}

impl ServerCallOutcome {
    /// Returns `true` only for [`ServerCallOutcome::Replied`].
    pub fn replied(self) -> bool {
        self == Self::Replied
    }
}
107
108pub trait ServerMiddleware: Send + Sync + 'static {
110 fn pre<'a>(&'a self, _context: &'a RequestContext<'a>) -> BoxMiddlewareFuture<'a> {
111 Box::pin(async {})
112 }
113
114 fn post<'a>(
115 &'a self,
116 _context: &'a RequestContext<'a>,
117 _outcome: ServerCallOutcome,
118 ) -> BoxMiddlewareFuture<'a> {
119 Box::pin(async {})
120 }
121}
122
123#[derive(Clone)]
124#[doc(hidden)]
125pub struct ServerCallOutcomeHandle {
126 outcome: Arc<Mutex<ServerCallOutcome>>,
127}
128
129impl ServerCallOutcomeHandle {
130 pub fn outcome(&self) -> ServerCallOutcome {
131 *self
132 .outcome
133 .lock()
134 .expect("server call outcome mutex poisoned")
135 }
136}
137
138#[doc(hidden)]
139pub struct ObservedReplySink<R> {
140 inner: Option<R>,
141 outcome: ServerCallOutcomeHandle,
142}
143
144#[doc(hidden)]
145pub fn observe_reply<R>(reply: R) -> (ObservedReplySink<R>, ServerCallOutcomeHandle) {
146 let outcome = ServerCallOutcomeHandle {
147 outcome: Arc::new(Mutex::new(ServerCallOutcome::DroppedWithoutReply)),
148 };
149 (
150 ObservedReplySink {
151 inner: Some(reply),
152 outcome: outcome.clone(),
153 },
154 outcome,
155 )
156}
157
158impl<R> ReplySink for ObservedReplySink<R>
159where
160 R: ReplySink,
161{
162 async fn send_reply(mut self, response: RequestResponse<'_>) {
163 *self
164 .outcome
165 .outcome
166 .lock()
167 .expect("server call outcome mutex poisoned") = ServerCallOutcome::Replied;
168 let reply = self
169 .inner
170 .take()
171 .expect("observed reply sink can only reply once");
172 reply.send_reply(response).await;
173 }
174
175 fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
176 self.inner.as_ref().and_then(|reply| reply.channel_binder())
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::{Extensions, ServerCallOutcome};

    #[test]
    fn extensions_store_values_by_type() {
        let extensions = Extensions::new();
        assert!(!extensions.contains::<u32>());
        assert_eq!(extensions.insert(41_u32), None);
        assert!(extensions.contains::<u32>());
        assert_eq!(extensions.get_cloned::<u32>(), Some(41));
        let updated = extensions.with_mut::<u32, _>(|value| {
            *value += 1;
            *value
        });
        assert_eq!(updated, Some(42));
        assert_eq!(extensions.get_cloned::<u32>(), Some(42));
    }

    // `insert` must hand back the value it replaced (this exercises the
    // downcast of the previous entry).
    #[test]
    fn extensions_insert_returns_replaced_value() {
        let extensions = Extensions::new();
        assert_eq!(extensions.insert(String::from("first")), None);
        assert_eq!(
            extensions.insert(String::from("second")),
            Some(String::from("first"))
        );
        assert_eq!(
            extensions.get_cloned::<String>().as_deref(),
            Some("second")
        );
    }

    // Values of different types live side by side under different keys.
    #[test]
    fn extensions_keep_distinct_types_separate() {
        let extensions = Extensions::new();
        assert_eq!(extensions.insert(1_u32), None);
        assert_eq!(extensions.insert(2_u64), None);
        assert_eq!(extensions.get_cloned::<u32>(), Some(1));
        assert_eq!(extensions.get_cloned::<u64>(), Some(2));
    }

    // Accessors on an absent type return `None` and never run the closure.
    #[test]
    fn extensions_missing_type_skips_closure() {
        let extensions = Extensions::new();
        let mut calls = 0;
        assert_eq!(extensions.with::<u8, _>(|_| calls += 1), None);
        assert_eq!(extensions.with_mut::<u8, _>(|_| calls += 1), None);
        assert_eq!(calls, 0);
        assert_eq!(extensions.get_cloned::<u8>(), None);
    }

    #[test]
    fn server_call_outcome_reports_reply_state() {
        assert!(ServerCallOutcome::Replied.replied());
        assert!(!ServerCallOutcome::DroppedWithoutReply.replied());
    }
}