1use std::{
2 any::{Any, TypeId},
3 collections::HashMap,
4 future::Future,
5 pin::Pin,
6 sync::{Arc, Mutex},
7};
8
9use facet_reflect::Peek;
10
11use crate::{
12 ConnectionId, MetadataEntry, MethodDescriptor, Payload, ReplySink, RequestContext, RequestId,
13 RequestResponse,
14};
15
/// Type-keyed storage for values attached to a request.
///
/// At most one value is stored per concrete type `T`; inserting a second `T`
/// replaces the first. Cloning an `Extensions` is cheap and produces a handle to
/// the *same* map (the storage lives behind an [`Arc`]), so a value inserted
/// through one clone is visible through every other clone.
///
/// # Locking
///
/// Every accessor takes an internal [`Mutex`]. The closures given to
/// [`Extensions::with`] and [`Extensions::with_mut`] run while that lock is held,
/// so they must not call back into this `Extensions` or any clone of it: the
/// mutex is not reentrant, and re-locking it from the thread that holds it will
/// not return (it may deadlock or panic).
///
/// A panic inside such a closure poisons the mutex. A user closure can never
/// leave the map's own structure half-updated (only the stored `T` it was handed
/// may be partially modified), so the accessors recover the lock instead of
/// turning every later access into a panic.
#[derive(Clone, Debug, Default)]
pub struct Extensions {
    inner: Arc<Mutex<HashMap<TypeId, Box<dyn Any + Send + Sync>>>>,
}

impl Extensions {
    /// Creates an empty extension map.
    pub fn new() -> Self {
        Self::default()
    }

    /// Locks the map, recovering the guard if a previous holder panicked.
    ///
    /// Poisoning can only come from a user closure in `with`/`with_mut`; the
    /// map's key -> boxed-value invariant is never broken mid-operation.
    fn lock_map(
        &self,
    ) -> std::sync::MutexGuard<'_, HashMap<TypeId, Box<dyn Any + Send + Sync>>> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Stores `value` under its type, returning the value it replaced, if any.
    ///
    /// The replaced value is returned only after the lock is released.
    ///
    /// # Panics
    ///
    /// Only if the internal invariant (each key's `TypeId` matches its boxed
    /// value) is violated, which this type never does.
    pub fn insert<T>(&self, value: T) -> Option<T>
    where
        T: Send + Sync + 'static,
    {
        // The guard is a temporary of this statement and is dropped at the `;`.
        let previous = self.lock_map().insert(TypeId::of::<T>(), Box::new(value));
        previous.map(|boxed| {
            *boxed
                .downcast::<T>()
                .expect("extensions type id and boxed value disagreed")
        })
    }

    /// Returns `true` if a value of type `T` is stored.
    pub fn contains<T>(&self) -> bool
    where
        T: Send + Sync + 'static,
    {
        self.lock_map().contains_key(&TypeId::of::<T>())
    }

    /// Runs `f` with a shared reference to the stored `T`.
    ///
    /// Returns `None` (without calling `f`) when no `T` is stored. `f` runs while
    /// the internal lock is held; it must not access this `Extensions` again.
    pub fn with<T, R>(&self, f: impl FnOnce(&T) -> R) -> Option<R>
    where
        T: Send + Sync + 'static,
    {
        let guard = self.lock_map();
        let value = guard
            .get(&TypeId::of::<T>())?
            .downcast_ref::<T>()
            .expect("extensions type id and boxed value disagreed");
        Some(f(value))
    }

    /// Runs `f` with a mutable reference to the stored `T`.
    ///
    /// Returns `None` (without calling `f`) when no `T` is stored. `f` runs while
    /// the internal lock is held; it must not access this `Extensions` again.
    pub fn with_mut<T, R>(&self, f: impl FnOnce(&mut T) -> R) -> Option<R>
    where
        T: Send + Sync + 'static,
    {
        let mut guard = self.lock_map();
        let value = guard
            .get_mut(&TypeId::of::<T>())?
            .downcast_mut::<T>()
            .expect("extensions type id and boxed value disagreed");
        Some(f(value))
    }

    /// Returns a clone of the stored `T`, or `None` if absent.
    pub fn get_cloned<T>(&self) -> Option<T>
    where
        T: Clone + Send + Sync + 'static,
    {
        self.with(|value: &T| value.clone())
    }
}
92
/// Boxed, type-erased future returned by [`ServerMiddleware`] hooks.
///
/// `wasm32` variant: the future carries no `Send` bound.
#[cfg(target_arch = "wasm32")]
pub type BoxMiddlewareFuture<'a> = Pin<Box<dyn Future<Output = ()> + 'a>>;
/// Boxed, type-erased future returned by [`ServerMiddleware`] hooks.
///
/// Non-`wasm32` variant: the future must be `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub type BoxMiddlewareFuture<'a> = Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
97
/// Final disposition of a server call, passed to [`ServerMiddleware::post`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerCallOutcome {
    /// A reply was sent through the observed reply sink.
    Replied,
    /// No reply was sent through the observed reply sink (the initial state).
    DroppedWithoutReply,
}

impl ServerCallOutcome {
    /// Returns `true` only for [`ServerCallOutcome::Replied`].
    pub fn replied(self) -> bool {
        match self {
            Self::Replied => true,
            Self::DroppedWithoutReply => false,
        }
    }
}
112
113#[derive(Clone, Copy, Debug)]
124pub struct ServerRequest<'a> {
125 context: RequestContext<'a>,
126 args: Peek<'a, 'static>,
127}
128
129impl<'a> ServerRequest<'a> {
130 pub const fn new(context: RequestContext<'a>, args: Peek<'a, 'static>) -> Self {
132 Self { context, args }
133 }
134
135 pub const fn context(&self) -> &RequestContext<'a> {
137 &self.context
138 }
139
140 pub fn method(&self) -> &'static crate::MethodDescriptor {
142 self.context.method()
143 }
144
145 pub fn metadata(&self) -> &'a [crate::MetadataEntry<'static>] {
147 self.context.metadata()
148 }
149
150 pub fn request_id(&self) -> Option<crate::RequestId> {
152 self.context.request_id()
153 }
154
155 pub fn connection_id(&self) -> Option<crate::ConnectionId> {
157 self.context.connection_id()
158 }
159
160 pub fn extensions(&self) -> &'a Extensions {
162 self.context.extensions()
163 }
164
165 pub const fn args(&self) -> Peek<'a, 'static> {
167 self.args
168 }
169}
170
171#[derive(Clone, Debug)]
173pub struct ServerResponseContext {
174 method: &'static MethodDescriptor,
175 request_id: Option<RequestId>,
176 connection_id: Option<ConnectionId>,
177 extensions: Extensions,
178}
179
180impl ServerResponseContext {
181 pub const fn new(
183 method: &'static MethodDescriptor,
184 request_id: Option<RequestId>,
185 connection_id: Option<ConnectionId>,
186 extensions: Extensions,
187 ) -> Self {
188 Self {
189 method,
190 request_id,
191 connection_id,
192 extensions,
193 }
194 }
195
196 pub const fn method(&self) -> &'static MethodDescriptor {
198 self.method
199 }
200
201 pub const fn request_id(&self) -> Option<RequestId> {
203 self.request_id
204 }
205
206 pub const fn connection_id(&self) -> Option<ConnectionId> {
208 self.connection_id
209 }
210
211 pub const fn extensions(&self) -> &Extensions {
213 &self.extensions
214 }
215}
216
/// Payload of a server response, as exposed to middleware.
#[derive(Clone, Copy, Debug)]
pub enum ServerResponsePayload<'a> {
    /// An in-memory value, exposed through reflection.
    Value(Peek<'a, 'static>),
    /// Raw bytes taken from `Payload::PostcardBytes` (presumably already
    /// postcard-encoded — the encoding is not checked here).
    PostcardBytes(&'a [u8]),
}

/// Borrowed, read-only view of a [`RequestResponse`], passed to
/// [`ServerMiddleware::response`] hooks.
#[derive(Clone, Copy, Debug)]
pub struct ServerResponse<'a> {
    metadata: &'a [MetadataEntry<'a>],
    payload: ServerResponsePayload<'a>,
}

impl<'a> ServerResponse<'a> {
    /// Builds a view over `response`.
    ///
    /// A `Payload::Value` is wrapped in a [`Peek`] built from its pointer and
    /// shape; `Payload::PostcardBytes` is passed through as a byte slice.
    pub fn new(response: &'a RequestResponse<'a>) -> Self {
        let payload = match &response.ret {
            Payload::Value { ptr, shape, .. } => {
                // SAFETY: relies on `Payload::Value` pairing `ptr` with the
                // `shape` that describes the value it points to, and on that
                // value staying alive and initialized while `response` is
                // borrowed for `'a`.
                // NOTE(review): this invariant must be upheld wherever
                // `Payload::Value` is constructed; confirm against that type.
                let peek = unsafe { Peek::unchecked_new(*ptr, *shape) };
                ServerResponsePayload::Value(peek)
            }
            Payload::PostcardBytes(bytes) => ServerResponsePayload::PostcardBytes(bytes),
        };
        Self {
            metadata: &response.metadata,
            payload,
        }
    }

    /// Metadata entries attached to the response.
    pub const fn metadata(&self) -> &'a [MetadataEntry<'a>] {
        self.metadata
    }

    /// The response payload (reflected value or raw bytes).
    pub const fn payload(&self) -> ServerResponsePayload<'a> {
        self.payload
    }

    /// The reflected payload value, or `None` when the payload is raw bytes.
    pub const fn payload_peek(&self) -> Option<Peek<'a, 'static>> {
        match self.payload {
            ServerResponsePayload::Value(peek) => Some(peek),
            ServerResponsePayload::PostcardBytes(_) => None,
        }
    }
}
261
262pub trait ServerMiddleware: Send + Sync + 'static {
264 fn pre<'a>(&'a self, _request: ServerRequest<'_>) -> BoxMiddlewareFuture<'a> {
265 Box::pin(async {})
266 }
267
268 fn response<'a>(
269 &'a self,
270 _context: &ServerResponseContext,
271 _response: ServerResponse<'_>,
272 ) -> BoxMiddlewareFuture<'a> {
273 Box::pin(async {})
274 }
275
276 fn post<'a>(
277 &'a self,
278 _context: &RequestContext<'_>,
279 _outcome: ServerCallOutcome,
280 ) -> BoxMiddlewareFuture<'a> {
281 Box::pin(async {})
282 }
283}
284
285#[derive(Clone)]
286#[doc(hidden)]
287pub struct ServerCallOutcomeHandle {
288 outcome: Arc<Mutex<ServerCallOutcome>>,
289}
290
291impl ServerCallOutcomeHandle {
292 pub fn outcome(&self) -> ServerCallOutcome {
293 *self
294 .outcome
295 .lock()
296 .expect("server call outcome mutex poisoned")
297 }
298}
299
300#[doc(hidden)]
301pub struct ObservedReplySink<R> {
302 inner: Option<R>,
303 outcome: ServerCallOutcomeHandle,
304 response_context: ServerResponseContext,
305 middlewares: Vec<Arc<dyn ServerMiddleware>>,
306}
307
308#[doc(hidden)]
309pub fn observe_reply<R>(
310 reply: R,
311 response_context: ServerResponseContext,
312 middlewares: Vec<Arc<dyn ServerMiddleware>>,
313) -> (ObservedReplySink<R>, ServerCallOutcomeHandle) {
314 let outcome = ServerCallOutcomeHandle {
315 outcome: Arc::new(Mutex::new(ServerCallOutcome::DroppedWithoutReply)),
316 };
317 (
318 ObservedReplySink {
319 inner: Some(reply),
320 outcome: outcome.clone(),
321 response_context,
322 middlewares,
323 },
324 outcome,
325 )
326}
327
/// Forwards replies to the wrapped sink after running middleware `response` hooks.
impl<R> ReplySink for ObservedReplySink<R>
where
    R: ReplySink,
{
    /// Runs every middleware `response` hook (awaited one at a time, in reverse
    /// order of the middleware list), records [`ServerCallOutcome::Replied`],
    /// then forwards `response` to the wrapped sink.
    ///
    /// The outcome is recorded before the inner `send_reply` is awaited. If this
    /// future is dropped during that inner send, the outcome still reads
    /// `Replied`.
    async fn send_reply(mut self, response: RequestResponse<'_>) {
        for middleware in self.middlewares.iter().rev() {
            // The `ServerResponse` view is rebuilt for each hook and moved into
            // the call, so it is never held across the `.await`.
            // NOTE(review): `ServerResponse` wraps a `Peek`; hoisting it out of
            // the loop would keep it alive across awaits and may make this future
            // non-`Send` — confirm before changing.
            middleware
                .response(&self.response_context, ServerResponse::new(&response))
                .await;
        }
        // Written as one statement so the `MutexGuard` temporary is dropped at
        // the `;`, before the `.await` below.
        *self
            .outcome
            .outcome
            .lock()
            .expect("server call outcome mutex poisoned") = ServerCallOutcome::Replied;
        // `inner` is `Some` from `observe_reply`, and `send_reply` consumes
        // `self`. In the code visible here it is taken only at this point.
        let reply = self
            .inner
            .take()
            .expect("observed reply sink can only reply once");
        reply.send_reply(response).await;
    }

    /// Delegates to the wrapped sink (`None` if it has already been taken).
    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
        self.inner.as_ref().and_then(|reply| reply.channel_binder())
    }

    /// Delegates to the wrapped sink (`None` if it has already been taken).
    fn request_id(&self) -> Option<crate::RequestId> {
        self.inner.as_ref().and_then(|reply| reply.request_id())
    }

    /// Delegates to the wrapped sink (`None` if it has already been taken).
    fn connection_id(&self) -> Option<crate::ConnectionId> {
        self.inner.as_ref().and_then(|reply| reply.connection_id())
    }
}
362
/// Unit tests for extensions, call outcomes, and the request view.
#[cfg(test)]
mod tests {
    use super::{Extensions, ServerCallOutcome, ServerRequest};
    use crate::{RequestContext, method_descriptor};

    #[test]
    fn extensions_store_values_by_type() {
        let extensions = Extensions::new();
        assert!(!extensions.contains::<u32>());
        assert_eq!(extensions.insert(41_u32), None);
        assert!(extensions.contains::<u32>());
        assert_eq!(extensions.get_cloned::<u32>(), Some(41));
        let updated = extensions.with_mut::<u32, _>(|value| {
            *value += 1;
            *value
        });
        assert_eq!(updated, Some(42));
        assert_eq!(extensions.get_cloned::<u32>(), Some(42));
    }

    // `insert` must hand back the value it replaced, not drop it.
    #[test]
    fn extensions_insert_returns_replaced_value() {
        let extensions = Extensions::new();
        assert_eq!(extensions.insert(String::from("first")), None);
        assert_eq!(
            extensions.insert(String::from("second")),
            Some(String::from("first"))
        );
        assert_eq!(extensions.get_cloned::<String>().as_deref(), Some("second"));
    }

    // Lookups for a type that was never inserted must not call the closure.
    #[test]
    fn extensions_report_absent_types_as_none() {
        let extensions = Extensions::new();
        extensions.insert(1_u32);
        assert!(!extensions.contains::<u64>());
        assert_eq!(extensions.with::<u64, _>(|value| *value), None);
        assert_eq!(extensions.with_mut::<u64, _>(|value| *value), None);
        assert_eq!(extensions.get_cloned::<u64>(), None);
    }

    // The map lives behind an `Arc`, so clones observe each other's inserts.
    #[test]
    fn extensions_clones_share_storage() {
        let original = Extensions::new();
        let clone = original.clone();
        clone.insert(7_u8);
        assert!(original.contains::<u8>());
        assert_eq!(original.get_cloned::<u8>(), Some(7));
    }

    #[test]
    fn server_call_outcome_reports_reply_state() {
        assert!(ServerCallOutcome::Replied.replied());
        assert!(!ServerCallOutcome::DroppedWithoutReply.replied());
    }

    #[test]
    fn server_request_exposes_context_and_decoded_args() {
        let method =
            method_descriptor::<(u32, u32), ()>("demo-service", "sum", &["left", "right"], None);
        let metadata = [];
        let extensions = Extensions::new();
        let context = RequestContext::with_extensions(method, &metadata, &extensions);
        let args = (7_u32, 35_u32);
        let request = ServerRequest::new(context, facet_reflect::Peek::new(&args));

        assert_eq!(request.method().method_name, "sum");
        assert_eq!(request.metadata().len(), 0);
        let tuple = request
            .args()
            .into_tuple()
            .expect("decoded args should be a tuple");
        let a = *tuple
            .field(0)
            .expect("first tuple field should exist")
            .get::<u32>()
            .expect("first tuple field should be u32");
        let b = *tuple
            .field(1)
            .expect("second tuple field should exist")
            .get::<u32>()
            .expect("second tuple field should be u32");
        assert_eq!((a, b), (7, 35));
    }
}