tycho_network/util/
router.rs1use std::marker::PhantomData;
2use std::sync::Arc;
3
4use tycho_util::FastHashMap;
5use tycho_util::futures::BoxFutureOrNoop;
6
7use crate::types::{BoxService, Service, ServiceExt};
8
/// Declares which request ids a service handles, so that a router builder can
/// map those ids to the service.
///
/// Both methods default to yielding nothing. An implementor therefore only
/// overrides the kind(s) of traffic it actually serves.
pub trait Routable {
    /// Ids of the queries this service answers. Defaults to none.
    #[inline]
    fn query_ids(&self) -> impl IntoIterator<Item = u32> {
        std::iter::empty::<u32>()
    }

    /// Ids of the one-way messages this service accepts. Defaults to none.
    #[inline]
    fn message_ids(&self) -> impl IntoIterator<Item = u32> {
        std::iter::empty::<u32>()
    }
}
20
21pub struct RouterBuilder<Request, Q> {
22 inner: Inner<Request, Q>,
23}
24
25impl<Request, Q> RouterBuilder<Request, Q> {
26 pub fn route<S>(mut self, service: S) -> Self
27 where
28 S: Service<Request, QueryResponse = Q> + Routable + Send + Sync + 'static,
29 {
30 let index = self.inner.services.len();
31 for id in service.query_ids() {
32 let prev = self.inner.query_handlers.insert(id, index);
33 assert!(prev.is_none(), "duplicate query id: {id:08x}");
34 }
35 for id in service.message_ids() {
36 let prev = self.inner.message_handlers.insert(id, index);
37 assert!(prev.is_none(), "duplicate message id: {id:08x}");
38 }
39
40 self.inner.services.push(service.boxed());
41 self
42 }
43
44 pub fn build(self) -> Router<Request, Q> {
45 Router {
46 inner: Arc::new(self.inner),
47 }
48 }
49}
50
51impl<Request, Q> Default for RouterBuilder<Request, Q> {
52 fn default() -> Self {
53 Self {
54 inner: Inner {
55 services: Vec::new(),
56 query_handlers: FastHashMap::default(),
57 message_handlers: FastHashMap::default(),
58 _response: PhantomData,
59 },
60 }
61 }
62}
63
64pub struct Router<Request, Q> {
65 inner: Arc<Inner<Request, Q>>,
66}
67
68impl<Request, Q> Router<Request, Q> {
69 pub fn builder() -> RouterBuilder<Request, Q> {
70 RouterBuilder::default()
71 }
72}
73
74impl<Request, Q> Clone for Router<Request, Q> {
75 #[inline]
76 fn clone(&self) -> Self {
77 Self {
78 inner: self.inner.clone(),
79 }
80 }
81}
82
83impl<Request, Q> Service<Request> for Router<Request, Q>
84where
85 Request: Send + AsRef<[u8]> + 'static,
86 Q: Send + 'static,
87{
88 type QueryResponse = Q;
89 type OnQueryFuture = BoxFutureOrNoop<Option<Self::QueryResponse>>;
90 type OnMessageFuture = BoxFutureOrNoop<()>;
91
92 fn on_query(&self, req: Request) -> Self::OnQueryFuture {
93 match find_handler(&req, &self.inner.query_handlers, &self.inner.services) {
94 Some(service) => BoxFutureOrNoop::Boxed(service.on_query(req)),
95 None => BoxFutureOrNoop::Noop,
96 }
97 }
98
99 fn on_message(&self, req: Request) -> Self::OnMessageFuture {
100 match find_handler(&req, &self.inner.message_handlers, &self.inner.services) {
101 Some(service) => BoxFutureOrNoop::Boxed(service.on_message(req)),
102 None => BoxFutureOrNoop::Noop,
103 }
104 }
105}
106
107fn find_handler<'a, T: AsRef<[u8]>, S>(
108 req: &T,
109 indices: &FastHashMap<u32, usize>,
110 handlers: &'a [S],
111) -> Option<&'a S> {
112 if let Some(id) = read_le_u32(req.as_ref())
113 && let Some(&index) = indices.get(&id)
114 {
115 return Some(handlers.get(index).expect("index must be in bounds"));
118 }
119 None
120}
121
122struct Inner<Request, Q> {
123 services: Vec<BoxService<Request, Q>>,
124 query_handlers: FastHashMap<u32, usize>,
125 message_handlers: FastHashMap<u32, usize>,
126 _response: PhantomData<Q>,
127}
128
/// Decodes the first four bytes of `buf` as a little-endian `u32`.
///
/// Any bytes after the first four are ignored. Returns `None` when `buf`
/// holds fewer than four bytes.
fn read_le_u32(buf: &[u8]) -> Option<u32> {
    buf.first_chunk::<4>().copied().map(u32::from_le_bytes)
}