tycho_network/util/router.rs

1use std::marker::PhantomData;
2use std::sync::Arc;
3
4use tycho_util::FastHashMap;
5use tycho_util::futures::BoxFutureOrNoop;
6
7use crate::types::{BoxService, Service, ServiceExt};
8
/// A service that declares which request ids it wants to receive.
///
/// The router compares the little-endian `u32` stored in the first four bytes
/// of a request against these ids. Query ids and message ids are kept in
/// separate tables, so the same numeric id may appear in both lists.
pub trait Routable {
    /// Ids of the queries this service handles.
    ///
    /// The default implementation yields no ids.
    #[inline]
    fn query_ids(&self) -> impl IntoIterator<Item = u32> {
        []
    }

    /// Ids of the messages this service handles.
    ///
    /// The default implementation yields no ids.
    #[inline]
    fn message_ids(&self) -> impl IntoIterator<Item = u32> {
        []
    }
}
20
/// Collects services and their query/message id mappings before producing a
/// [`Router`].
///
/// Obtain one with [`Router::builder`] or `RouterBuilder::default()`, register
/// services with `route`, then call `build`.
pub struct RouterBuilder<Request, Q> {
    /// Routing state; `build` moves it into the router behind an `Arc`.
    inner: Inner<Request, Q>,
}
24
25impl<Request, Q> RouterBuilder<Request, Q> {
26    pub fn route<S>(mut self, service: S) -> Self
27    where
28        S: Service<Request, QueryResponse = Q> + Routable + Send + Sync + 'static,
29    {
30        let index = self.inner.services.len();
31        for id in service.query_ids() {
32            let prev = self.inner.query_handlers.insert(id, index);
33            assert!(prev.is_none(), "duplicate query id: {id:08x}");
34        }
35        for id in service.message_ids() {
36            let prev = self.inner.message_handlers.insert(id, index);
37            assert!(prev.is_none(), "duplicate message id: {id:08x}");
38        }
39
40        self.inner.services.push(service.boxed());
41        self
42    }
43
44    pub fn build(self) -> Router<Request, Q> {
45        Router {
46            inner: Arc::new(self.inner),
47        }
48    }
49}
50
51impl<Request, Q> Default for RouterBuilder<Request, Q> {
52    fn default() -> Self {
53        Self {
54            inner: Inner {
55                services: Vec::new(),
56                query_handlers: FastHashMap::default(),
57                message_handlers: FastHashMap::default(),
58                _response: PhantomData,
59            },
60        }
61    }
62}
63
/// Dispatches requests to registered services.
///
/// A request is routed by the little-endian `u32` id in its first four bytes.
/// Clones share the same routing table through an `Arc`.
pub struct Router<Request, Q> {
    inner: Arc<Inner<Request, Q>>,
}
67
68impl<Request, Q> Router<Request, Q> {
69    pub fn builder() -> RouterBuilder<Request, Q> {
70        RouterBuilder::default()
71    }
72}
73
74impl<Request, Q> Clone for Router<Request, Q> {
75    #[inline]
76    fn clone(&self) -> Self {
77        Self {
78            inner: self.inner.clone(),
79        }
80    }
81}
82
83impl<Request, Q> Service<Request> for Router<Request, Q>
84where
85    Request: Send + AsRef<[u8]> + 'static,
86    Q: Send + 'static,
87{
88    type QueryResponse = Q;
89    type OnQueryFuture = BoxFutureOrNoop<Option<Self::QueryResponse>>;
90    type OnMessageFuture = BoxFutureOrNoop<()>;
91
92    fn on_query(&self, req: Request) -> Self::OnQueryFuture {
93        match find_handler(&req, &self.inner.query_handlers, &self.inner.services) {
94            Some(service) => BoxFutureOrNoop::Boxed(service.on_query(req)),
95            None => BoxFutureOrNoop::Noop,
96        }
97    }
98
99    fn on_message(&self, req: Request) -> Self::OnMessageFuture {
100        match find_handler(&req, &self.inner.message_handlers, &self.inner.services) {
101            Some(service) => BoxFutureOrNoop::Boxed(service.on_message(req)),
102            None => BoxFutureOrNoop::Noop,
103        }
104    }
105}
106
107fn find_handler<'a, T: AsRef<[u8]>, S>(
108    req: &T,
109    indices: &FastHashMap<u32, usize>,
110    handlers: &'a [S],
111) -> Option<&'a S> {
112    if let Some(id) = read_le_u32(req.as_ref())
113        && let Some(&index) = indices.get(&id)
114    {
115        // NOTE: intentionally panics if index is out of bounds as it is
116        // an implementation error.
117        return Some(handlers.get(index).expect("index must be in bounds"));
118    }
119    None
120}
121
/// Shared routing state.
///
/// Holds the registered services plus separate id → service-position tables
/// for queries and messages.
struct Inner<Request, Q> {
    /// Services in registration order; both tables below index into this.
    services: Vec<BoxService<Request, Q>>,
    /// Query id → index into `services`.
    query_handlers: FastHashMap<u32, usize>,
    /// Message id → index into `services`.
    message_handlers: FastHashMap<u32, usize>,
    // NOTE(review): `Q` already appears in `services`, so this marker looks
    // redundant. Confirm it is not needed before removing it.
    _response: PhantomData<Q>,
}
128
/// Decodes the first four bytes of `buf` as a little-endian `u32`.
///
/// Returns `None` if `buf` holds fewer than four bytes; any trailing bytes are
/// ignored.
fn read_le_u32(buf: &[u8]) -> Option<u32> {
    let head = buf.get(..4)?;
    Some(u32::from_le_bytes([head[0], head[1], head[2], head[3]]))
}