Skip to main content

vox_types/
client_middleware.rs

1use std::sync::Arc;
2
3use crate::server_middleware::BoxMiddlewareFuture;
4use crate::{
5    BoxFut, CallResult, Caller, Extensions, Metadata, MetadataEntry, MetadataFlags, MetadataValue,
6    MethodDescriptor, MethodId, RequestCall, VoxError, ServiceDescriptor,
7};
8
9/// Borrowed per-call context exposed to client middleware.
10#[derive(Clone, Copy, Debug)]
11pub struct ClientContext<'a> {
12    method: Option<&'static MethodDescriptor>,
13    method_id: MethodId,
14    extensions: &'a Extensions,
15}
16
17impl<'a> ClientContext<'a> {
18    pub fn new(
19        method: Option<&'static MethodDescriptor>,
20        method_id: MethodId,
21        extensions: &'a Extensions,
22    ) -> Self {
23        Self {
24            method,
25            method_id,
26            extensions,
27        }
28    }
29
30    pub fn method(&self) -> Option<&'static MethodDescriptor> {
31        self.method
32    }
33
34    pub fn method_id(&self) -> MethodId {
35        self.method_id
36    }
37
38    pub fn extensions(&self) -> &'a Extensions {
39        self.extensions
40    }
41}
42
43/// Borrowed request wrapper exposed to client middleware.
44///
45/// This allows middleware to add dynamic metadata while keeping the backing
46/// storage alive until the wrapped caller finishes sending the request.
47pub struct ClientRequest<'call, 'state> {
48    call: &'state mut RequestCall<'call>,
49    owned_metadata: &'state mut OwnedMetadata,
50}
51
52impl<'call, 'state> ClientRequest<'call, 'state> {
53    pub(crate) fn new(
54        call: &'state mut RequestCall<'call>,
55        owned_metadata: &'state mut OwnedMetadata,
56    ) -> Self {
57        Self {
58            call,
59            owned_metadata,
60        }
61    }
62
63    pub fn call(&self) -> &RequestCall<'call> {
64        self.call
65    }
66
67    pub fn metadata(&self) -> &[MetadataEntry<'call>] {
68        &self.call.metadata
69    }
70
71    pub fn metadata_mut(&mut self) -> &mut Metadata<'call> {
72        &mut self.call.metadata
73    }
74
75    pub fn push_string_metadata(
76        &mut self,
77        key: &'static str,
78        value: impl Into<String>,
79        flags: MetadataFlags,
80    ) {
81        self.owned_metadata
82            .strings
83            .push(value.into().into_boxed_str());
84        let stored = self.owned_metadata.strings.last().unwrap();
85        // SAFETY: The boxed string is heap-allocated (stable address) and owned by
86        // `owned_metadata`, which lives in the same stack frame as `call` in
87        // MiddlewareCaller::call. It won't be dropped until after `call` is consumed.
88        let value: &'call str = unsafe { &*((&**stored) as *const str) };
89        self.call.metadata.push(MetadataEntry {
90            key,
91            value: MetadataValue::String(value),
92            flags,
93        });
94    }
95
96    pub fn push_bytes_metadata(
97        &mut self,
98        key: &'static str,
99        value: impl Into<Vec<u8>>,
100        flags: MetadataFlags,
101    ) {
102        self.owned_metadata
103            .bytes
104            .push(value.into().into_boxed_slice());
105        let stored = self.owned_metadata.bytes.last().unwrap();
106        // SAFETY: same reasoning as push_string_metadata above.
107        let value: &'call [u8] = unsafe { &*((&**stored) as *const [u8]) };
108        self.call.metadata.push(MetadataEntry {
109            key,
110            value: MetadataValue::Bytes(value),
111            flags,
112        });
113    }
114
115    pub fn push_u64_metadata(&mut self, key: &'static str, value: u64, flags: MetadataFlags) {
116        self.call.metadata.push(MetadataEntry {
117            key,
118            value: MetadataValue::U64(value),
119            flags,
120        });
121    }
122}
123
124#[derive(Default)]
125pub(crate) struct OwnedMetadata {
126    strings: Vec<Box<str>>,
127    bytes: Vec<Box<[u8]>>,
128}
129
130#[derive(Clone, Copy)]
131pub enum ClientCallOutcome<'a> {
132    Response,
133    Error(&'a VoxError),
134}
135
136impl ClientCallOutcome<'_> {
137    pub fn is_ok(self) -> bool {
138        matches!(self, Self::Response)
139    }
140}
141
142pub trait ClientMiddleware: Send + Sync + 'static {
143    fn pre<'a, 'call>(
144        &'a self,
145        _context: &'a ClientContext<'a>,
146        _request: &'a mut ClientRequest<'call, 'a>,
147    ) -> BoxMiddlewareFuture<'a> {
148        Box::pin(async {})
149    }
150
151    fn post<'a>(
152        &'a self,
153        _context: &'a ClientContext<'a>,
154        _outcome: ClientCallOutcome<'a>,
155    ) -> BoxMiddlewareFuture<'a> {
156        Box::pin(async {})
157    }
158}
159
160#[derive(Clone)]
161pub struct MiddlewareCaller<C> {
162    caller: C,
163    service: &'static ServiceDescriptor,
164    middlewares: Vec<Arc<dyn ClientMiddleware>>,
165}
166
167impl<C> MiddlewareCaller<C> {
168    pub fn new(caller: C, service: &'static ServiceDescriptor) -> Self {
169        Self {
170            caller,
171            service,
172            middlewares: vec![],
173        }
174    }
175
176    pub fn with_middleware(mut self, middleware: impl ClientMiddleware) -> Self {
177        self.middlewares.push(Arc::new(middleware));
178        self
179    }
180}
181
182impl<C> Caller for MiddlewareCaller<C>
183where
184    C: Caller,
185{
186    async fn call<'a>(&'a self, mut call: RequestCall<'a>) -> CallResult {
187        let extensions = Extensions::new();
188        let method = self.service.by_id(call.method_id);
189        let context = ClientContext::new(method, call.method_id, &extensions);
190        let mut owned_metadata = OwnedMetadata::default();
191        if !self.middlewares.is_empty() {
192            for middleware in &self.middlewares {
193                let mut request = ClientRequest::new(&mut call, &mut owned_metadata);
194                middleware.pre(&context, &mut request).await;
195            }
196        }
197
198        let result = self.caller.call(call).await;
199        if !self.middlewares.is_empty() {
200            let outcome = match &result {
201                Ok(_) => ClientCallOutcome::Response,
202                Err(error) => ClientCallOutcome::Error(error),
203            };
204            for middleware in self.middlewares.iter().rev() {
205                middleware.post(&context, outcome).await;
206            }
207        }
208        result
209    }
210
211    fn closed(&self) -> BoxFut<'_, ()> {
212        self.caller.closed()
213    }
214
215    fn is_connected(&self) -> bool {
216        self.caller.is_connected()
217    }
218
219    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
220        self.caller.channel_binder()
221    }
222}
223
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use crate::{Backing, Payload};

    use super::{
        BoxMiddlewareFuture, ClientCallOutcome, ClientContext, ClientMiddleware, ClientRequest,
        MetadataFlags, MethodDescriptor, MethodId, MiddlewareCaller, OwnedMetadata, RequestCall,
    };
    use crate::{CallResult, Caller};
    use crate::{RequestResponse, SelfRef};

    /// Single-method service shared by the `MiddlewareCaller` tests.
    static AUDIT_RECORD: MethodDescriptor = MethodDescriptor {
        id: MethodId(7),
        service_name: "Audit",
        method_name: "record",
        args_shape: <() as facet::Facet<'static>>::SHAPE,
        args: &[],
        return_shape: <() as facet::Facet<'static>>::SHAPE,
        retry: crate::RetryPolicy::VOLATILE,
        doc: None,
    };
    static AUDIT_SERVICE: crate::ServiceDescriptor = crate::ServiceDescriptor {
        service_name: "Audit",
        methods: &[&AUDIT_RECORD],
        doc: None,
    };

    #[test]
    fn client_request_can_add_owned_metadata() {
        // Declare the backing store before `call` so it is dropped after it:
        // `call` must never hold references into already-freed storage.
        let mut owned = OwnedMetadata::default();
        let mut call = RequestCall {
            method_id: MethodId(1),
            metadata: vec![],
            args: Payload::PostcardBytes(&[]),
            schemas: Default::default(),
        };
        let mut request = ClientRequest::new(&mut call, &mut owned);
        request.push_string_metadata("x-test", "value".to_string(), MetadataFlags::NONE);
        request.push_bytes_metadata("x-bytes", vec![1, 2, 3], MetadataFlags::NONE);
        request.push_u64_metadata("x-num", 7, MetadataFlags::NONE);

        assert_eq!(request.metadata().len(), 3);
        assert!(matches!(
            request.metadata()[0].value,
            crate::MetadataValue::String("value")
        ));
        assert!(matches!(
            request.metadata()[1].value,
            crate::MetadataValue::Bytes(bytes) if bytes == [1, 2, 3]
        ));
        assert!(matches!(
            request.metadata()[2].value,
            crate::MetadataValue::U64(7)
        ));
    }

    /// Caller that records the metadata it receives and always succeeds.
    #[derive(Clone)]
    struct RecordingCaller {
        seen_metadata: Arc<Mutex<Vec<String>>>,
    }

    impl Caller for RecordingCaller {
        async fn call<'a>(&'a self, call: RequestCall<'a>) -> CallResult {
            let seen = call
                .metadata
                .iter()
                .map(|entry| match entry.value {
                    crate::MetadataValue::String(value) => format!("{}={value}", entry.key),
                    crate::MetadataValue::Bytes(bytes) => {
                        format!("{}=<{} bytes>", entry.key, bytes.len())
                    }
                    crate::MetadataValue::U64(value) => format!("{}={value}", entry.key),
                })
                .collect::<Vec<_>>();
            *self
                .seen_metadata
                .lock()
                .expect("seen metadata mutex poisoned") = seen;

            Ok(crate::WithTracker {
                value: SelfRef::owning(
                    Backing::Boxed(Box::<[u8]>::default()),
                    RequestResponse {
                        metadata: vec![],
                        ret: Payload::PostcardBytes(&[]),
                        schemas: Default::default(),
                    },
                ),
                tracker: Arc::new(crate::SchemaRecvTracker::new()),
            })
        }
    }

    #[derive(Clone)]
    struct InjectMetadata;

    impl ClientMiddleware for InjectMetadata {
        fn pre<'a, 'call>(
            &'a self,
            context: &'a ClientContext<'a>,
            request: &'a mut ClientRequest<'call, 'a>,
        ) -> BoxMiddlewareFuture<'a> {
            Box::pin(async move {
                context.extensions().insert(41_u32);
                request.push_string_metadata("x-test", "value".to_string(), MetadataFlags::NONE);
            })
        }

        fn post<'a>(
            &'a self,
            context: &'a ClientContext<'a>,
            outcome: ClientCallOutcome<'a>,
        ) -> BoxMiddlewareFuture<'a> {
            Box::pin(async move {
                assert_eq!(context.extensions().get_cloned::<u32>(), Some(41));
                assert!(outcome.is_ok());
            })
        }
    }

    /// Middleware that logs each hook invocation and tags the request with its name.
    struct OrderProbe {
        name: &'static str,
        log: Arc<Mutex<Vec<String>>>,
    }

    impl ClientMiddleware for OrderProbe {
        fn pre<'a, 'call>(
            &'a self,
            _context: &'a ClientContext<'a>,
            request: &'a mut ClientRequest<'call, 'a>,
        ) -> BoxMiddlewareFuture<'a> {
            Box::pin(async move {
                self.log
                    .lock()
                    .expect("order log mutex poisoned")
                    .push(format!("pre:{}", self.name));
                request.push_string_metadata("x-order", self.name, MetadataFlags::NONE);
            })
        }

        fn post<'a>(
            &'a self,
            _context: &'a ClientContext<'a>,
            _outcome: ClientCallOutcome<'a>,
        ) -> BoxMiddlewareFuture<'a> {
            Box::pin(async move {
                self.log
                    .lock()
                    .expect("order log mutex poisoned")
                    .push(format!("post:{}", self.name));
            })
        }
    }

    #[tokio::test]
    async fn middleware_caller_runs_hooks_and_mutates_metadata() {
        let seen_metadata = Arc::new(Mutex::new(Vec::new()));
        let caller = MiddlewareCaller::new(
            RecordingCaller {
                seen_metadata: Arc::clone(&seen_metadata),
            },
            &AUDIT_SERVICE,
        )
        .with_middleware(InjectMetadata);

        let response: CallResult = caller
            .call(RequestCall {
                method_id: MethodId(7),
                metadata: vec![],
                args: Payload::PostcardBytes(&[]),
                schemas: Default::default(),
            })
            .await;

        assert!(response.is_ok());
        assert_eq!(
            *seen_metadata.lock().expect("seen metadata mutex poisoned"),
            vec!["x-test=value".to_string()]
        );
    }

    #[tokio::test]
    async fn middleware_caller_without_middleware_forwards_request_unchanged() {
        let seen_metadata = Arc::new(Mutex::new(Vec::new()));
        let caller = MiddlewareCaller::new(
            RecordingCaller {
                seen_metadata: Arc::clone(&seen_metadata),
            },
            &AUDIT_SERVICE,
        );

        let response: CallResult = caller
            .call(RequestCall {
                method_id: MethodId(7),
                metadata: vec![crate::MetadataEntry {
                    key: "x-existing",
                    value: crate::MetadataValue::U64(3),
                    flags: MetadataFlags::NONE,
                }],
                args: Payload::PostcardBytes(&[]),
                schemas: Default::default(),
            })
            .await;

        assert!(response.is_ok());
        assert_eq!(
            *seen_metadata.lock().expect("seen metadata mutex poisoned"),
            vec!["x-existing=3".to_string()]
        );
    }

    #[tokio::test]
    async fn middleware_pre_runs_in_order_and_post_in_reverse() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let seen_metadata = Arc::new(Mutex::new(Vec::new()));
        let caller = MiddlewareCaller::new(
            RecordingCaller {
                seen_metadata: Arc::clone(&seen_metadata),
            },
            &AUDIT_SERVICE,
        )
        .with_middleware(OrderProbe {
            name: "outer",
            log: Arc::clone(&log),
        })
        .with_middleware(OrderProbe {
            name: "inner",
            log: Arc::clone(&log),
        });

        let response: CallResult = caller
            .call(RequestCall {
                method_id: MethodId(7),
                metadata: vec![],
                args: Payload::PostcardBytes(&[]),
                schemas: Default::default(),
            })
            .await;

        assert!(response.is_ok());
        assert_eq!(
            *log.lock().expect("order log mutex poisoned"),
            vec![
                "pre:outer".to_string(),
                "pre:inner".to_string(),
                "post:inner".to_string(),
                "post:outer".to_string(),
            ]
        );
        assert_eq!(
            *seen_metadata.lock().expect("seen metadata mutex poisoned"),
            vec!["x-order=outer".to_string(), "x-order=inner".to_string()]
        );
    }
}