1use std::sync::Arc;
2
3use crate::server_middleware::BoxMiddlewareFuture;
4use crate::{
5 BoxFut, CallResult, Caller, Extensions, Metadata, MetadataEntry, MetadataFlags, MetadataValue,
6 MethodDescriptor, MethodId, RequestCall, VoxError, ServiceDescriptor,
7};
8
/// Read-only view of the call that a client middleware hook is running for.
///
/// Bundles the method descriptor (when one is available), the method id, and a
/// borrowed [`Extensions`] value. The type is `Copy`, so it is cheap to pass
/// around by value or by reference.
#[derive(Clone, Copy, Debug)]
pub struct ClientContext<'a> {
    /// Descriptor of the method, or `None` when the creator had none to supply.
    method: Option<&'static MethodDescriptor>,
    /// Identifier of the method being called.
    method_id: MethodId,
    /// Extensions borrowed from whoever built this context.
    extensions: &'a Extensions,
}
16
17impl<'a> ClientContext<'a> {
18 pub fn new(
19 method: Option<&'static MethodDescriptor>,
20 method_id: MethodId,
21 extensions: &'a Extensions,
22 ) -> Self {
23 Self {
24 method,
25 method_id,
26 extensions,
27 }
28 }
29
30 pub fn method(&self) -> Option<&'static MethodDescriptor> {
31 self.method
32 }
33
34 pub fn method_id(&self) -> MethodId {
35 self.method_id
36 }
37
38 pub fn extensions(&self) -> &'a Extensions {
39 self.extensions
40 }
41}
42
/// Mutable handle on an outgoing request, handed to middleware `pre` hooks.
///
/// `'call` is the lifetime of the data borrowed by the [`RequestCall`];
/// `'state` is how long this handle borrows the call and the owned-metadata
/// storage.
pub struct ClientRequest<'call, 'state> {
    /// The request being prepared; metadata entries are appended to it.
    call: &'state mut RequestCall<'call>,
    /// Backing storage for string/byte metadata values created through this handle.
    owned_metadata: &'state mut OwnedMetadata,
}
51
impl<'call, 'state> ClientRequest<'call, 'state> {
    /// Wraps a request together with the storage that backs owned metadata values.
    ///
    /// The `push_string_metadata` / `push_bytes_metadata` methods hand out
    /// `'call` references into `owned_metadata`, so the creator must keep that
    /// storage alive (and unmodified) for as long as the call's metadata is used.
    pub(crate) fn new(
        call: &'state mut RequestCall<'call>,
        owned_metadata: &'state mut OwnedMetadata,
    ) -> Self {
        Self {
            call,
            owned_metadata,
        }
    }

    /// Shared access to the underlying request.
    pub fn call(&self) -> &RequestCall<'call> {
        self.call
    }

    /// The metadata entries currently attached to the request.
    pub fn metadata(&self) -> &[MetadataEntry<'call>] {
        &self.call.metadata
    }

    /// Mutable access to the request's metadata collection.
    pub fn metadata_mut(&mut self) -> &mut Metadata<'call> {
        &mut self.call.metadata
    }

    /// Appends a string metadata entry whose value is owned by this request's
    /// backing storage rather than borrowed from the caller.
    ///
    /// The value is converted to a `Box<str>`, stored in the owned-metadata
    /// storage, and the request's metadata receives a reference to it.
    pub fn push_string_metadata(
        &mut self,
        key: &'static str,
        value: impl Into<String>,
        flags: MetadataFlags,
    ) {
        // Store the box first, then derive the pointer from the stored box, so
        // the pointer is created after the box has reached its final owner.
        self.owned_metadata
            .strings
            .push(value.into().into_boxed_str());
        // Cannot fail: an element was pushed on the line above.
        let stored = self.owned_metadata.strings.last().unwrap();
        // SAFETY: `stored` is a `Box<str>`, so the string bytes live in their own
        // heap allocation; growing the `strings` vector moves only the box
        // pointer, not the bytes. This file never removes or mutates stored
        // entries. The `'call` lifetime is NOT enforced by the type system here:
        // soundness relies on the `OwnedMetadata` outliving every use of the
        // call's metadata. TODO(review): confirm for each `ClientRequest::new` caller.
        let value: &'call str = unsafe { &*((&**stored) as *const str) };
        self.call.metadata.push(MetadataEntry {
            key,
            value: MetadataValue::String(value),
            flags,
        });
    }

    /// Appends a bytes metadata entry whose value is owned by this request's
    /// backing storage rather than borrowed from the caller.
    ///
    /// The value is converted to a `Box<[u8]>`, stored in the owned-metadata
    /// storage, and the request's metadata receives a reference to it.
    pub fn push_bytes_metadata(
        &mut self,
        key: &'static str,
        value: impl Into<Vec<u8>>,
        flags: MetadataFlags,
    ) {
        // Same ordering as the string variant: store first, then take the pointer.
        self.owned_metadata
            .bytes
            .push(value.into().into_boxed_slice());
        // Cannot fail: an element was pushed on the line above.
        let stored = self.owned_metadata.bytes.last().unwrap();
        // SAFETY: same argument as in `push_string_metadata` — the slice lives in
        // its own heap allocation that is not moved or freed while the
        // `OwnedMetadata` is alive, and the creator of this request must keep
        // that storage alive for as long as the call's metadata is used.
        let value: &'call [u8] = unsafe { &*((&**stored) as *const [u8]) };
        self.call.metadata.push(MetadataEntry {
            key,
            value: MetadataValue::Bytes(value),
            flags,
        });
    }

    /// Appends a `u64` metadata entry; the value is stored inline, so no
    /// backing storage is involved.
    pub fn push_u64_metadata(&mut self, key: &'static str, value: u64, flags: MetadataFlags) {
        self.call.metadata.push(MetadataEntry {
            key,
            value: MetadataValue::U64(value),
            flags,
        });
    }
}
123
/// Storage for metadata values created by middleware.
///
/// Each value sits in its own box, so growing a vector moves only the box
/// pointers and never the value bytes themselves; `ClientRequest` relies on
/// this when it hands out references to stored values.
#[derive(Default)]
pub(crate) struct OwnedMetadata {
    /// String values pushed through `ClientRequest::push_string_metadata`.
    strings: Vec<Box<str>>,
    /// Byte values pushed through `ClientRequest::push_bytes_metadata`.
    bytes: Vec<Box<[u8]>>,
}
129
/// Result of a call as seen by middleware `post` hooks.
///
/// A lightweight, `Copy` summary: the success case carries no payload and the
/// failure case only borrows the error.
#[derive(Clone, Copy)]
pub enum ClientCallOutcome<'a> {
    /// The call completed with a response.
    Response,
    /// The call failed with the borrowed error.
    Error(&'a VoxError),
}
135
136impl ClientCallOutcome<'_> {
137 pub fn is_ok(self) -> bool {
138 matches!(self, Self::Response)
139 }
140}
141
/// Hooks that run around an outgoing client call.
///
/// Both methods have default implementations that do nothing, so an
/// implementer only overrides the hooks it needs. `MiddlewareCaller` in this
/// file invokes `pre` on each registered middleware in registration order
/// before the wrapped caller runs, and `post` in reverse order afterwards.
pub trait ClientMiddleware: Send + Sync + 'static {
    /// Runs before the call is forwarded.
    ///
    /// `_context` describes the call; `_request` allows inspecting the request
    /// and appending metadata. The default implementation is a no-op.
    fn pre<'a, 'call>(
        &'a self,
        _context: &'a ClientContext<'a>,
        _request: &'a mut ClientRequest<'call, 'a>,
    ) -> BoxMiddlewareFuture<'a> {
        Box::pin(async {})
    }

    /// Runs after the call has finished.
    ///
    /// `_outcome` tells whether the call produced a response or an error. The
    /// default implementation is a no-op.
    fn post<'a>(
        &'a self,
        _context: &'a ClientContext<'a>,
        _outcome: ClientCallOutcome<'a>,
    ) -> BoxMiddlewareFuture<'a> {
        Box::pin(async {})
    }
}
159
/// A [`Caller`] wrapper that runs [`ClientMiddleware`] hooks around every call.
///
/// Cloning copies the wrapped caller and shares the middleware instances
/// (they are held behind `Arc`).
#[derive(Clone)]
pub struct MiddlewareCaller<C> {
    /// The wrapped caller that actually performs the call.
    caller: C,
    /// Service descriptor used to look up the method descriptor by id.
    service: &'static ServiceDescriptor,
    /// Registered middleware, in registration order.
    middlewares: Vec<Arc<dyn ClientMiddleware>>,
}
166
167impl<C> MiddlewareCaller<C> {
168 pub fn new(caller: C, service: &'static ServiceDescriptor) -> Self {
169 Self {
170 caller,
171 service,
172 middlewares: vec![],
173 }
174 }
175
176 pub fn with_middleware(mut self, middleware: impl ClientMiddleware) -> Self {
177 self.middlewares.push(Arc::new(middleware));
178 self
179 }
180}
181
182impl<C> Caller for MiddlewareCaller<C>
183where
184 C: Caller,
185{
186 async fn call<'a>(&'a self, mut call: RequestCall<'a>) -> CallResult {
187 let extensions = Extensions::new();
188 let method = self.service.by_id(call.method_id);
189 let context = ClientContext::new(method, call.method_id, &extensions);
190 let mut owned_metadata = OwnedMetadata::default();
191 if !self.middlewares.is_empty() {
192 for middleware in &self.middlewares {
193 let mut request = ClientRequest::new(&mut call, &mut owned_metadata);
194 middleware.pre(&context, &mut request).await;
195 }
196 }
197
198 let result = self.caller.call(call).await;
199 if !self.middlewares.is_empty() {
200 let outcome = match &result {
201 Ok(_) => ClientCallOutcome::Response,
202 Err(error) => ClientCallOutcome::Error(error),
203 };
204 for middleware in self.middlewares.iter().rev() {
205 middleware.post(&context, outcome).await;
206 }
207 }
208 result
209 }
210
211 fn closed(&self) -> BoxFut<'_, ()> {
212 self.caller.closed()
213 }
214
215 fn is_connected(&self) -> bool {
216 self.caller.is_connected()
217 }
218
219 fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
220 self.caller.channel_binder()
221 }
222}
223
// Unit tests for owned metadata on `ClientRequest` and for the hook ordering
// and metadata propagation of `MiddlewareCaller`.
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use crate::{Backing, Payload};

    use super::{
        BoxMiddlewareFuture, ClientCallOutcome, ClientContext, ClientMiddleware, ClientRequest,
        MetadataFlags, MethodDescriptor, MethodId, MiddlewareCaller, OwnedMetadata, RequestCall,
    };
    use crate::{CallResult, Caller};
    use crate::{RequestResponse, SelfRef};

    // Pushing one value of each kind yields three entries, in push order, with
    // the expected values.
    #[test]
    fn client_request_can_add_owned_metadata() {
        let mut call = RequestCall {
            method_id: MethodId(1),
            metadata: vec![],
            args: Payload::PostcardBytes(&[]),
            schemas: Default::default(),
        };
        let mut owned = OwnedMetadata::default();
        let mut request = ClientRequest::new(&mut call, &mut owned);
        request.push_string_metadata("x-test", "value".to_string(), MetadataFlags::NONE);
        request.push_bytes_metadata("x-bytes", vec![1, 2, 3], MetadataFlags::NONE);
        request.push_u64_metadata("x-num", 7, MetadataFlags::NONE);

        assert_eq!(request.metadata().len(), 3);
        assert!(matches!(
            request.metadata()[0].value,
            crate::MetadataValue::String("value")
        ));
        assert!(matches!(
            request.metadata()[1].value,
            crate::MetadataValue::Bytes(bytes) if bytes == [1, 2, 3]
        ));
        assert!(matches!(
            request.metadata()[2].value,
            crate::MetadataValue::U64(7)
        ));
    }

    // Test double: records a textual rendering of the metadata it receives and
    // answers with an empty successful response.
    #[derive(Clone)]
    struct RecordingCaller {
        seen_metadata: Arc<Mutex<Vec<String>>>,
    }

    impl Caller for RecordingCaller {
        async fn call<'a>(&'a self, call: RequestCall<'a>) -> CallResult {
            // Render every entry as `key=value` (bytes are rendered by length only).
            let seen = call
                .metadata
                .iter()
                .map(|entry| match entry.value {
                    crate::MetadataValue::String(value) => format!("{}={value}", entry.key),
                    crate::MetadataValue::Bytes(bytes) => {
                        format!("{}=<{} bytes>", entry.key, bytes.len())
                    }
                    crate::MetadataValue::U64(value) => format!("{}={value}", entry.key),
                })
                .collect::<Vec<_>>();
            // Replace (not append to) whatever was recorded before.
            *self
                .seen_metadata
                .lock()
                .expect("seen metadata mutex poisoned") = seen;

            Ok(crate::WithTracker {
                value: SelfRef::owning(
                    Backing::Boxed(Box::<[u8]>::default()),
                    RequestResponse {
                        metadata: vec![],
                        ret: Payload::PostcardBytes(&[]),
                        schemas: Default::default(),
                    },
                ),
                tracker: std::sync::Arc::new(crate::SchemaRecvTracker::new()),
            })
        }
    }

    // Middleware under test: `pre` stores a value in the extensions and adds a
    // metadata entry; `post` checks that the stored value is still visible and
    // that the call succeeded.
    #[derive(Clone)]
    struct InjectMetadata;

    impl ClientMiddleware for InjectMetadata {
        fn pre<'a, 'call>(
            &'a self,
            context: &'a ClientContext<'a>,
            request: &'a mut ClientRequest<'call, 'a>,
        ) -> BoxMiddlewareFuture<'a> {
            Box::pin(async move {
                context.extensions().insert(41_u32);
                request.push_string_metadata("x-test", "value".to_string(), MetadataFlags::NONE);
            })
        }

        fn post<'a>(
            &'a self,
            context: &'a ClientContext<'a>,
            outcome: ClientCallOutcome<'a>,
        ) -> BoxMiddlewareFuture<'a> {
            Box::pin(async move {
                assert_eq!(context.extensions().get_cloned::<u32>(), Some(41));
                assert!(outcome.is_ok());
            })
        }
    }

    // End to end: the wrapped caller must observe the metadata injected by the
    // `pre` hook, and the `post` hook's assertions must hold.
    #[tokio::test]
    async fn middleware_caller_runs_hooks_and_mutates_metadata() {
        static METHOD: MethodDescriptor = MethodDescriptor {
            id: MethodId(7),
            service_name: "Audit",
            method_name: "record",
            args_shape: <() as facet::Facet<'static>>::SHAPE,
            args: &[],
            return_shape: <() as facet::Facet<'static>>::SHAPE,
            retry: crate::RetryPolicy::VOLATILE,
            doc: None,
        };
        static SERVICE: crate::ServiceDescriptor = crate::ServiceDescriptor {
            service_name: "Audit",
            methods: &[&METHOD],
            doc: None,
        };

        let seen_metadata = Arc::new(Mutex::new(Vec::new()));
        let caller = MiddlewareCaller::new(
            RecordingCaller {
                seen_metadata: Arc::clone(&seen_metadata),
            },
            &SERVICE,
        )
        .with_middleware(InjectMetadata);

        let response: CallResult = caller
            .call(RequestCall {
                method_id: MethodId(7),
                metadata: vec![],
                args: Payload::PostcardBytes(&[]),
                schemas: Default::default(),
            })
            .await;

        assert!(response.is_ok());
        assert_eq!(
            *seen_metadata.lock().expect("seen metadata mutex poisoned"),
            vec!["x-test=value".to_string()]
        );
    }
}