Skip to main content

vox_types/
server_middleware.rs

1use std::{
2    any::{Any, TypeId},
3    collections::HashMap,
4    future::Future,
5    pin::Pin,
6    sync::{Arc, Mutex},
7};
8
9use facet_reflect::Peek;
10
11use crate::{
12    ConnectionId, MetadataEntry, MethodDescriptor, Payload, ReplySink, RequestContext, RequestId,
13    RequestResponse,
14};
15
/// Type-indexed map backing [`Extensions`]: at most one boxed value per [`TypeId`].
type ExtensionMap = HashMap<TypeId, Box<dyn Any + Send + Sync>>;

/// Per-request type-indexed storage shared across middleware hooks and handlers.
///
/// The bag holds at most one value per concrete type `T`. Cloning an
/// `Extensions` is cheap and yields another handle to the *same* storage (it is
/// an `Arc<Mutex<_>>` internally), so a value inserted through one clone is
/// visible through every other clone.
///
/// # Locking
///
/// Every accessor takes an internal, non-reentrant [`Mutex`].
/// [`with`](Self::with), [`with_mut`](Self::with_mut) and
/// [`get_cloned`](Self::get_cloned) keep it held while caller code runs (the
/// closure, or `T::clone`). That code must not access the same bag, or any
/// clone of it, again: the nested lock attempt will deadlock or panic.
///
/// If caller code panics while the lock is held, the bag stays usable. The lock
/// recovers from poisoning, and the affected value keeps whatever state the
/// caller left it in.
#[derive(Clone, Debug, Default)]
pub struct Extensions {
    inner: Arc<Mutex<ExtensionMap>>,
}

impl Extensions {
    /// Create a new, empty extensions bag.
    pub fn new() -> Self {
        Self::default()
    }

    /// Lock the map, recovering the guard if an earlier holder panicked.
    ///
    /// The only code that runs under this lock besides plain `HashMap` lookups
    /// and inserts is caller code (a `with`/`with_mut` closure or a `Clone`
    /// impl). That code can only touch a single stored value, so the map itself
    /// is never left structurally inconsistent. Recovering keeps one panicking
    /// hook from turning every later access to this bag into a panic as well.
    fn lock(&self) -> std::sync::MutexGuard<'_, ExtensionMap> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Insert `value`, returning the previously stored value of the same type, if any.
    ///
    /// # Panics
    ///
    /// Panics only if the stored value's type disagrees with its `TypeId` key,
    /// which would be an internal invariant violation.
    pub fn insert<T>(&self, value: T) -> Option<T>
    where
        T: Send + Sync + 'static,
    {
        // The guard is a temporary, so the lock is released before the old value
        // is downcast and handed back to the caller (who may then drop it).
        let previous = self.lock().insert(TypeId::of::<T>(), Box::new(value));
        previous.map(|boxed| {
            *boxed
                .downcast::<T>()
                .expect("extensions type id and boxed value disagreed")
        })
    }

    /// Returns `true` if a value of type `T` is present.
    pub fn contains<T>(&self) -> bool
    where
        T: Send + Sync + 'static,
    {
        self.lock().contains_key(&TypeId::of::<T>())
    }

    /// Borrow the stored `T` for the duration of `f`.
    ///
    /// Returns `None` without calling `f` when no `T` is stored. The bag's lock
    /// is held while `f` runs, so `f` must not access this bag (or a clone of it).
    pub fn with<T, R>(&self, f: impl FnOnce(&T) -> R) -> Option<R>
    where
        T: Send + Sync + 'static,
    {
        let guard = self.lock();
        let value = guard
            .get(&TypeId::of::<T>())?
            .downcast_ref::<T>()
            .expect("extensions type id and boxed value disagreed");
        Some(f(value))
    }

    /// Mutably borrow the stored `T` for the duration of `f`.
    ///
    /// Returns `None` without calling `f` when no `T` is stored. The bag's lock
    /// is held while `f` runs, so `f` must not access this bag (or a clone of it).
    /// If `f` panics, the value keeps whatever partial update `f` made.
    pub fn with_mut<T, R>(&self, f: impl FnOnce(&mut T) -> R) -> Option<R>
    where
        T: Send + Sync + 'static,
    {
        let mut guard = self.lock();
        let value = guard
            .get_mut(&TypeId::of::<T>())?
            .downcast_mut::<T>()
            .expect("extensions type id and boxed value disagreed");
        Some(f(value))
    }

    /// Return a clone of the stored `T`, or `None` if absent.
    ///
    /// `T::clone` runs while the bag's lock is held.
    pub fn get_cloned<T>(&self) -> Option<T>
    where
        T: Clone + Send + Sync + 'static,
    {
        self.with(|value: &T| value.clone())
    }
}
92
/// Boxed, type-erased future returned by [`ServerMiddleware`] hooks (wasm32 variant).
///
/// On `wasm32` the `Send` bound is omitted, presumably because futures there
/// run single-threaded and may hold non-`Send` handles.
#[cfg(target_arch = "wasm32")]
pub type BoxMiddlewareFuture<'a> = Pin<Box<dyn Future<Output = ()> + 'a>>;
/// Boxed, type-erased future returned by [`ServerMiddleware`] hooks.
///
/// On every non-`wasm32` target the future is required to be `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub type BoxMiddlewareFuture<'a> = Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
97
/// Outcome observed by server middleware after handler dispatch.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ServerCallOutcome {
    /// The handler sent a reply through the reply sink.
    Replied,
    /// The handler returned without replying; the runtime will synthesize cancellation.
    DroppedWithoutReply,
}

impl ServerCallOutcome {
    /// Returns `true` if the handler replied, i.e. this is [`ServerCallOutcome::Replied`].
    ///
    /// Usable in `const` contexts.
    #[must_use]
    pub const fn replied(self) -> bool {
        matches!(self, Self::Replied)
    }
}
112
113/// Middleware-facing view of one decoded server request.
114///
115/// This is built by generated dispatchers after the inbound payload has been
116/// deserialized into the method's typed argument tuple. The tuple is then
117/// exposed reflectively through [`Peek`], allowing middleware to inspect the
118/// decoded request without re-deserializing it.
119///
120/// Because this is a borrowed reflective view, middleware should extract any
121/// owned data it needs before awaiting. The view itself is intended for
122/// synchronous inspection within the hook body.
123#[derive(Clone, Copy, Debug)]
124pub struct ServerRequest<'a> {
125    context: RequestContext<'a>,
126    args: Peek<'a, 'static>,
127}
128
129impl<'a> ServerRequest<'a> {
130    /// Create a new middleware request view from a request context and decoded args.
131    pub const fn new(context: RequestContext<'a>, args: Peek<'a, 'static>) -> Self {
132        Self { context, args }
133    }
134
135    /// Borrowed per-request context for this call.
136    pub const fn context(&self) -> &RequestContext<'a> {
137        &self.context
138    }
139
140    /// Static descriptor for the method being handled.
141    pub fn method(&self) -> &'static crate::MethodDescriptor {
142        self.context.method()
143    }
144
145    /// Request metadata borrowed from the inbound call.
146    pub fn metadata(&self) -> &'a [crate::MetadataEntry<'static>] {
147        self.context.metadata()
148    }
149
150    /// Wire-level request identifier for this call, when available.
151    pub fn request_id(&self) -> Option<crate::RequestId> {
152        self.context.request_id()
153    }
154
155    /// Virtual connection identifier for this call, when available.
156    pub fn connection_id(&self) -> Option<crate::ConnectionId> {
157        self.context.connection_id()
158    }
159
160    /// Per-request middleware extensions bag.
161    pub fn extensions(&self) -> &'a Extensions {
162        self.context.extensions()
163    }
164
165    /// Reflective view of the decoded argument tuple for this call.
166    pub const fn args(&self) -> Peek<'a, 'static> {
167        self.args
168    }
169}
170
171/// Owned context available when observing an outbound server response.
172#[derive(Clone, Debug)]
173pub struct ServerResponseContext {
174    method: &'static MethodDescriptor,
175    request_id: Option<RequestId>,
176    connection_id: Option<ConnectionId>,
177    extensions: Extensions,
178}
179
180impl ServerResponseContext {
181    /// Create a response context from transport identifiers and shared extensions.
182    pub const fn new(
183        method: &'static MethodDescriptor,
184        request_id: Option<RequestId>,
185        connection_id: Option<ConnectionId>,
186        extensions: Extensions,
187    ) -> Self {
188        Self {
189            method,
190            request_id,
191            connection_id,
192            extensions,
193        }
194    }
195
196    /// Static descriptor for the method being handled.
197    pub const fn method(&self) -> &'static MethodDescriptor {
198        self.method
199    }
200
201    /// Wire-level request identifier for this call, when available.
202    pub const fn request_id(&self) -> Option<RequestId> {
203        self.request_id
204    }
205
206    /// Virtual connection identifier for this call, when available.
207    pub const fn connection_id(&self) -> Option<ConnectionId> {
208        self.connection_id
209    }
210
211    /// Per-request middleware extensions bag.
212    pub const fn extensions(&self) -> &Extensions {
213        &self.extensions
214    }
215}
216
217/// Reflective view of one outbound server response payload.
218#[derive(Clone, Copy, Debug)]
219pub enum ServerResponsePayload<'a> {
220    Value(Peek<'a, 'static>),
221    PostcardBytes(&'a [u8]),
222}
223
224/// Middleware-facing view of one outbound server response.
225#[derive(Clone, Copy, Debug)]
226pub struct ServerResponse<'a> {
227    metadata: &'a [MetadataEntry<'a>],
228    payload: ServerResponsePayload<'a>,
229}
230
231impl<'a> ServerResponse<'a> {
232    pub fn new(response: &'a RequestResponse<'a>) -> Self {
233        let payload = match &response.ret {
234            Payload::Value { ptr, shape, .. } => {
235                let peek = unsafe { Peek::unchecked_new(*ptr, *shape) };
236                ServerResponsePayload::Value(peek)
237            }
238            Payload::PostcardBytes(bytes) => ServerResponsePayload::PostcardBytes(bytes),
239        };
240        Self {
241            metadata: &response.metadata,
242            payload,
243        }
244    }
245
246    pub const fn metadata(&self) -> &'a [MetadataEntry<'a>] {
247        self.metadata
248    }
249
250    pub const fn payload(&self) -> ServerResponsePayload<'a> {
251        self.payload
252    }
253
254    pub const fn payload_peek(&self) -> Option<Peek<'a, 'static>> {
255        match self.payload {
256            ServerResponsePayload::Value(peek) => Some(peek),
257            ServerResponsePayload::PostcardBytes(_) => None,
258        }
259    }
260}
261
262/// Observe inbound server requests before and after dispatch.
263pub trait ServerMiddleware: Send + Sync + 'static {
264    fn pre<'a>(&'a self, _request: ServerRequest<'_>) -> BoxMiddlewareFuture<'a> {
265        Box::pin(async {})
266    }
267
268    fn response<'a>(
269        &'a self,
270        _context: &ServerResponseContext,
271        _response: ServerResponse<'_>,
272    ) -> BoxMiddlewareFuture<'a> {
273        Box::pin(async {})
274    }
275
276    fn post<'a>(
277        &'a self,
278        _context: &RequestContext<'_>,
279        _outcome: ServerCallOutcome,
280    ) -> BoxMiddlewareFuture<'a> {
281        Box::pin(async {})
282    }
283}
284
285#[derive(Clone)]
286#[doc(hidden)]
287pub struct ServerCallOutcomeHandle {
288    outcome: Arc<Mutex<ServerCallOutcome>>,
289}
290
291impl ServerCallOutcomeHandle {
292    pub fn outcome(&self) -> ServerCallOutcome {
293        *self
294            .outcome
295            .lock()
296            .expect("server call outcome mutex poisoned")
297    }
298}
299
/// Reply sink wrapper that lets server middleware observe the outbound reply.
///
/// Created by [`observe_reply`]. When the handler replies, the wrapper does three things in order:
/// 1. Runs each middleware's [`ServerMiddleware::response`] hook.
/// 2. Records [`ServerCallOutcome::Replied`] in the shared outcome handle.
/// 3. Forwards the reply to the wrapped sink.
#[doc(hidden)]
pub struct ObservedReplySink<R> {
    /// Wrapped sink; `Some` from construction until `send_reply` takes it.
    inner: Option<R>,
    /// Shared outcome cell; `observe_reply` returns another handle to the same cell.
    outcome: ServerCallOutcomeHandle,
    /// Owned context passed to every `response` hook.
    response_context: ServerResponseContext,
    /// Middleware chain; `send_reply` runs their `response` hooks in reverse order.
    middlewares: Vec<Arc<dyn ServerMiddleware>>,
}
307
308#[doc(hidden)]
309pub fn observe_reply<R>(
310    reply: R,
311    response_context: ServerResponseContext,
312    middlewares: Vec<Arc<dyn ServerMiddleware>>,
313) -> (ObservedReplySink<R>, ServerCallOutcomeHandle) {
314    let outcome = ServerCallOutcomeHandle {
315        outcome: Arc::new(Mutex::new(ServerCallOutcome::DroppedWithoutReply)),
316    };
317    (
318        ObservedReplySink {
319            inner: Some(reply),
320            outcome: outcome.clone(),
321            response_context,
322            middlewares,
323        },
324        outcome,
325    )
326}
327
328impl<R> ReplySink for ObservedReplySink<R>
329where
330    R: ReplySink,
331{
332    async fn send_reply(mut self, response: RequestResponse<'_>) {
333        for middleware in self.middlewares.iter().rev() {
334            middleware
335                .response(&self.response_context, ServerResponse::new(&response))
336                .await;
337        }
338        *self
339            .outcome
340            .outcome
341            .lock()
342            .expect("server call outcome mutex poisoned") = ServerCallOutcome::Replied;
343        let reply = self
344            .inner
345            .take()
346            .expect("observed reply sink can only reply once");
347        reply.send_reply(response).await;
348    }
349
350    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
351        self.inner.as_ref().and_then(|reply| reply.channel_binder())
352    }
353
354    fn request_id(&self) -> Option<crate::RequestId> {
355        self.inner.as_ref().and_then(|reply| reply.request_id())
356    }
357
358    fn connection_id(&self) -> Option<crate::ConnectionId> {
359        self.inner.as_ref().and_then(|reply| reply.connection_id())
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::{Extensions, ServerCallOutcome, ServerRequest};
    use crate::{RequestContext, method_descriptor};

    /// Values are keyed by their type and can be mutated in place.
    #[test]
    fn extensions_store_values_by_type() {
        let bag = Extensions::new();
        assert!(!bag.contains::<u32>());

        let displaced = bag.insert(41_u32);
        assert_eq!(displaced, None);
        assert!(bag.contains::<u32>());
        assert_eq!(bag.get_cloned::<u32>(), Some(41));

        let after_increment = bag.with_mut(|slot: &mut u32| {
            *slot += 1;
            *slot
        });
        assert_eq!(after_increment, Some(42));
        assert_eq!(bag.get_cloned::<u32>(), Some(42));
    }

    /// `replied()` is true only for the `Replied` variant.
    #[test]
    fn server_call_outcome_reports_reply_state() {
        let cases = [
            (ServerCallOutcome::Replied, true),
            (ServerCallOutcome::DroppedWithoutReply, false),
        ];
        for (outcome, expected) in cases {
            assert_eq!(outcome.replied(), expected);
        }
    }

    /// The request view forwards context accessors and exposes the decoded tuple.
    #[test]
    fn server_request_exposes_context_and_decoded_args() {
        let method =
            method_descriptor::<(u32, u32), ()>("demo-service", "sum", &["left", "right"], None);
        let metadata = [];
        let extensions = Extensions::new();
        let args = (7_u32, 35_u32);
        let request = ServerRequest::new(
            RequestContext::with_extensions(method, &metadata, &extensions),
            facet_reflect::Peek::new(&args),
        );

        assert_eq!(request.method().method_name, "sum");
        assert!(request.metadata().is_empty());

        let tuple = request
            .args()
            .into_tuple()
            .expect("decoded args should be a tuple");
        let left = *tuple
            .field(0)
            .expect("first tuple field should exist")
            .get::<u32>()
            .expect("first tuple field should be u32");
        let right = *tuple
            .field(1)
            .expect("second tuple field should exist")
            .get::<u32>()
            .expect("second tuple field should be u32");
        assert_eq!((left, right), (7, 35));
    }
}