volo_grpc/server/router.rs

1use std::{
2    fmt,
3    sync::atomic::{AtomicU32, Ordering},
4};
5
6use http_body::Body as HttpBody;
7use motore::{BoxCloneService, Service};
8use rustc_hash::FxHashMap;
9use volo::Unwrap;
10
11use super::NamedService;
12use crate::{Request, Response, Status, body::BoxBody, context::ServerContext};
13
/// Identifier that links an entry in a router's `matchit` path tree to the
/// service stored under the same key in the router's `routes` map.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct RouteId(u32);

impl RouteId {
    /// Returns a fresh id from a single process-wide counter.
    ///
    /// Because the counter is a `static`, ids are unique across every router
    /// in the process, not just within one.
    ///
    /// # Panics
    ///
    /// Panics once `u32::MAX` ids have been handed out. The counter is never
    /// advanced past `u32::MAX`, so every later call panics as well instead of
    /// wrapping around to 0 and silently reusing ids that are already in use.
    fn next() -> Self {
        // `AtomicU64` isn't supported on all platforms
        static ID: AtomicU32 = AtomicU32::new(0);
        // `checked_add` makes the update fail (rather than wrap to 0) once the
        // counter reaches `u32::MAX`; `fetch_update` then leaves it untouched.
        match ID.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |id| id.checked_add(1)) {
            Ok(id) => Self(id),
            Err(_) => {
                panic!("Over `u32::MAX` routes created. If you need this, please file an issue.")
            }
        }
    }
}
28
29#[derive(Default)]
30pub struct Router<B = BoxBody> {
31    routes:
32        FxHashMap<RouteId, BoxCloneService<ServerContext, Request<B>, Response<BoxBody>, Status>>,
33    node: matchit::Router<RouteId>,
34}
35
36impl<B> Clone for Router<B> {
37    fn clone(&self) -> Self {
38        Self {
39            routes: self.routes.clone(),
40            node: self.node.clone(),
41        }
42    }
43}
44
45impl<B> Router<B>
46where
47    B: HttpBody + 'static,
48{
49    pub fn new() -> Self {
50        Self {
51            routes: Default::default(),
52            node: Default::default(),
53        }
54    }
55
56    pub fn add_service<S>(mut self, service: S) -> Self
57    where
58        S: Service<ServerContext, Request<B>, Response = Response<BoxBody>, Error = Status>
59            + NamedService
60            + Clone
61            + Send
62            + Sync
63            + 'static,
64    {
65        let path = format!("/{}/{{*rest}}", S::NAME);
66
67        if path.is_empty() {
68            panic!("[VOLO] Paths must start with a `/`. Use \"/\" for root routes");
69        } else if !path.starts_with('/') {
70            panic!("[VOLO] Paths must start with a `/`");
71        }
72
73        let id = RouteId::next();
74
75        self.set_node(path, id);
76
77        self.routes.insert(id, BoxCloneService::new(service));
78
79        self
80    }
81
82    #[track_caller]
83    fn set_node(&mut self, path: String, id: RouteId) {
84        if let Err(err) = self.node.insert(path, id) {
85            panic!("[VOLO] Invalid route: {err}");
86        }
87    }
88}
89
90impl<B> Service<ServerContext, Request<B>> for Router<B>
91where
92    B: HttpBody + Send,
93{
94    type Response = Response<BoxBody>;
95    type Error = Status;
96
97    async fn call(
98        &self,
99        cx: &mut ServerContext,
100        req: Request<B>,
101    ) -> Result<Self::Response, Self::Error> {
102        let path = cx.rpc_info.method();
103        match self.node.at(path) {
104            Ok(match_) => {
105                let id = match_.value;
106                let route = self.routes.get(id).volo_unwrap().clone();
107                route.call(cx, req).await
108            }
109            Err(err) => Err(Status::unimplemented(err.to_string())),
110        }
111    }
112}
113
114impl<B> fmt::Debug for Router<B> {
115    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116        f.debug_struct("Router")
117            .field("routes", &self.routes)
118            .finish()
119    }
120}