Skip to main content

zlink_core/server/
mod.rs

1pub(crate) mod listener;
2mod select_all;
3pub mod service;
4
5use alloc::vec::Vec;
6use futures_util::{FutureExt, StreamExt};
7use select_all::SelectAll;
8use service::MethodReply;
9
10use crate::{connection::Socket, Call, Connection, Reply};
11
/// A server.
///
/// The server listens for incoming connections and handles method calls using a service.
/// Construct it with [`Server::new`] and drive it by awaiting [`Server::run`].
#[derive(Debug)]
pub struct Server<Listener, Service> {
    /// Source of incoming connections. Stored as an `Option` because `run` moves it out of
    /// `self` with `Option::take`.
    listener: Option<Listener>,
    /// Handles the method calls read from accepted connections.
    service: Service,
}
20
impl<Listener, Service> Server<Listener, Service>
where
    Listener: listener::Listener,
    Service: service::Service<Listener::Socket>,
{
    /// Create a new server that serves `service` to incoming connections from `listener`.
    ///
    /// This only stores its arguments; no connection is accepted until the future returned by
    /// [`Server::run`] is polled.
    pub fn new(listener: Listener, service: Service) -> Self {
        Self {
            listener: Some(listener),
            service,
        }
    }

    /// Run the server.
    ///
    /// This accepts connections from the listener, reads method calls from every connection
    /// that is not currently serving a reply stream, hands each call to the service, and
    /// forwards the items of any multi-reply streams the service returns. All of this is
    /// multiplexed inside the single future returned by this method using `select_biased!`.
    ///
    /// # Errors
    ///
    /// The method only returns when accepting a new connection fails; that error is propagated
    /// with `?`, so `Ok(())` is never actually produced. Failures on an individual connection
    /// (reading a call, writing a reply or error) are logged with `warn!`, the affected
    /// connection is dropped, and the server keeps running.
    ///
    /// # Caveats
    ///
    /// Due to [a bug in the rust compiler][abrc], the future returned by this method cannot be
    /// treated as `Send`, even if all the specific types involved are `Send`. A major consequence
    /// of this, unfortunately, is that it cannot be spawned in a task of a multi-threaded
    /// runtime. For example, you cannot currently do `tokio::spawn(server.run())`.
    ///
    /// Fortunately, there are easy workarounds for this. You can either:
    ///
    /// * Use a thread-local runtime (for example [`tokio::runtime::LocalRuntime`] or
    ///   [`tokio::task::LocalSet`]) to run the server in a local task, perhaps in a separate
    ///   thread.
    /// * Use some common API to run multiple futures at once, such as [`futures::select!`] or
    ///   [`tokio::select!`].
    ///
    /// Most importantly, this is most likely a temporary issue and will be fixed in the future. 😊
    ///
    /// [abrc]: https://github.com/rust-lang/rust/issues/100013
    /// [`tokio::runtime::LocalRuntime`]: https://docs.rs/tokio/latest/tokio/runtime/struct.LocalRuntime.html
    /// [`tokio::task::LocalSet`]: https://docs.rs/tokio/latest/tokio/task/struct.LocalSet.html
    /// [`futures::select!`]: https://docs.rs/futures/latest/futures/macro.select.html
    /// [`tokio::select!`]: https://docs.rs/tokio/latest/tokio/macro.select.html
    pub async fn run(mut self) -> crate::Result<()> {
        // `new` always stores `Some(listener)` and `run` consumes `self`, so through `new` this
        // `take()` runs exactly once and never sees `None`.
        let mut listener = self.listener.take().unwrap();
        // Connections we are waiting on for their next method call.
        let mut connections = Vec::new();
        // Connections currently owned by a multi-reply stream. They are not read from until the
        // stream ends, at which point they are moved back into `connections`.
        let mut reply_streams = Vec::<ReplyStream<Service::ReplyStream, Listener::Socket>>::new();
        // Vec for the `next()` futures of `reply_streams`. Reused across iterations to avoid
        // per-iteration allocations.
        let mut reply_stream_futures = Vec::new();
        // Vec for futures from `Connection::receive_call`. Reused across iterations to avoid
        // per-iteration allocations.
        let mut read_futures = Vec::new();
        // Index of the future that won the previous round of each `SelectAll`. Passed (plus one)
        // to `SelectAll::new`, presumably so polling starts after the last winner and one busy
        // source cannot starve the others — see the `select_all` module to confirm.
        let mut last_reply_stream_winner = None;
        let mut last_method_call_winner = None;

        loop {
            // We re-populate the `reply_stream_futures` in each iteration so we must clear it
            // first.
            reply_stream_futures.clear();
            {
                // SAFETY: Rust has no way to know that we don't re-use the mutable references in
                // each iteration (since we clear the `reply_stream_futures` vector) so we need to
                // go through a pointer to work around this.
                //
                // NOTE(review): the `select_biased!` branches below also index, push to and
                // `swap_remove` from `reply_streams` while the futures created here (which hold
                // `&mut` borrows of its elements) are still alive until the `clear()` above in
                // the next iteration. Soundness relies on those stale futures never being polled
                // again and on their drop not touching the borrowed stream — TODO confirm.
                let reply_streams: &mut Vec<ReplyStream<Service::ReplyStream, Listener::Socket>> =
                    unsafe { &mut *(&mut reply_streams as *mut Vec<_>) };
                reply_stream_futures.extend(reply_streams.iter_mut().map(|s| s.stream.next()));
            }
            let start_index = last_reply_stream_winner.map(|idx| idx + 1);
            let mut reply_stream_select_all = SelectAll::new(start_index);
            for future in reply_stream_futures.iter_mut() {
                reply_stream_select_all.push(future);
            }

            // Prepare futures for reading method calls from connections.
            read_futures.clear();
            {
                // SAFETY: Same as above - mutable references are not reused across iterations.
                //
                // NOTE(review): as with `reply_streams`, the branches below push to and
                // `swap_remove` from `connections`, and hand `&mut connections[idx]` to
                // `handle_call` while the `call` produced by one of these futures still carries
                // a borrow derived from that same element — TODO confirm this aliasing is sound.
                let connections: &mut Vec<Connection<Listener::Socket>> =
                    unsafe { &mut *(&mut connections as *mut Vec<_>) };
                read_futures.extend(connections.iter_mut().map(|c| c.receive_call()));
            }
            let mut read_select_all = SelectAll::new(last_method_call_winner.map(|idx| idx + 1));
            for future in &mut read_futures {
                // SAFETY: Futures in `read_futures` are dropped in place via `clear()` at the
                // start of the next iteration, never moved while pinned.
                unsafe {
                    read_select_all.push_unchecked(future);
                }
            }

            // `select_biased!` polls its branches in the order written, so when several are
            // ready at once: accepting wins over reading calls, which wins over stream replies.
            futures_util::select_biased! {
                // 1. Accept a new connection. An accept error ends `run` with that error.
                conn = listener.accept().fuse() => {
                    connections.push(conn?);
                }
                // 2. Read method calls from the existing connections and handle them.
                (idx, result) = read_select_all.fuse() => {
                        // With `std`, a successful read also yields the file descriptors that
                        // arrived with the call; they are handed to the service below.
                        #[cfg(feature = "std")]
                        let (call, fds) = match result {
                            Ok((call, fds)) => (Ok(call), fds),
                            Err(e) => (Err(e), alloc::vec![]),
                        };
                        #[cfg(not(feature = "std"))]
                        let call = result;
                        last_method_call_winner = Some(idx);

                        let mut stream = None;
                        // The connection stays in `connections` only when the call was handled
                        // and `handle_call` returned `Ok(None)`; every other outcome takes it out.
                        let mut remove = true;
                        match call {
                            Ok(call) => {
                                #[cfg(feature = "std")]
                                let result =
                                    self.handle_call(call, &mut connections[idx], fds).await;
                                #[cfg(not(feature = "std"))]
                                let result =
                                    self.handle_call(call, &mut connections[idx]).await;
                                match result {
                                    Ok(None) => remove = false,
                                    Ok(Some(s)) => stream = Some(s),
                                    Err(e) => warn!("Error writing to connection: {:?}", e),
                                }
                            }
                            Err(e) => warn!("Error reading from socket: {:?}", e),
                        }

                        if stream.is_some() || remove {
                            // `swap_remove` is O(1); it moves the last connection into slot `idx`.
                            let conn = connections.swap_remove(idx);

                            // A multi-reply hands the connection over to its stream; otherwise
                            // (an error above) `conn` is simply dropped here.
                            if let Some(stream) = stream {
                                reply_streams.push(ReplyStream::new(stream, conn));
                            }
                        }
                }
                // 3. Read replies from the reply streams and send them off. A failed send drops
                // the stream together with its connection; an exhausted stream returns its
                // connection to `connections` so its next method call can be read.
                reply = reply_stream_select_all.fuse() => {
                    let (idx, item) = reply;
                    last_reply_stream_winner = Some(idx);
                    // Captured up front for logging, before the entry may be removed.
                    let id = reply_streams[idx].conn.id();

                    match item {
                        Some(item) => {
                            #[cfg(feature = "std")]
                            let (reply, fds) = item;
                            #[cfg(not(feature = "std"))]
                            let reply = item;

                            #[cfg(feature = "std")]
                            let send_result =
                                reply_streams[idx].conn.send_reply(&reply, fds).await;
                            #[cfg(not(feature = "std"))]
                            let send_result = reply_streams[idx].conn.send_reply(&reply).await;
                            if let Err(e) = send_result {
                                warn!("Error writing to client {}: {:?}", id, e);
                                reply_streams.swap_remove(idx);
                            }
                        }
                        None => {
                            trace!("Stream closed for client {}", id);
                            let stream = reply_streams.swap_remove(idx);
                            connections.push(stream.conn);
                        }
                    }
                }
            }
        }
    }

    /// Dispatch `call` to the service and send the resulting reply on `conn`.
    ///
    /// A single reply or an error is written to `conn`, unless the call is oneway, in which case
    /// nothing is sent (with `std`, the file descriptors returned by the service are then
    /// dropped unsent). A multi-reply is not sent here; its stream is returned so the caller can
    /// forward the items. This happens for oneway calls too, since the oneway check only covers
    /// single replies and errors.
    ///
    /// With `std`, `fds` are the file descriptors received along with the call; they are passed
    /// to the service.
    ///
    /// # Returns
    ///
    /// `Ok(Some(stream))` for a multi-reply, `Ok(None)` otherwise.
    ///
    /// # Errors
    ///
    /// Propagates any error from writing the reply or the error to `conn`.
    async fn handle_call(
        &mut self,
        call: Call<Service::MethodCall<'_>>,
        conn: &mut Connection<Listener::Socket>,
        #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
    ) -> crate::Result<Option<Service::ReplyStream>> {
        let mut stream = None;

        #[cfg(feature = "std")]
        let (reply, reply_fds) = self.service.handle(&call, conn, fds).await;
        #[cfg(not(feature = "std"))]
        let reply = self.service.handle(&call, conn).await;

        match reply {
            // Don't send replies or errors for oneway calls.
            MethodReply::Single(_) | MethodReply::Error(_) if call.oneway() => (),
            MethodReply::Single(params) => {
                // The only reply to this call, so `continues` is explicitly set to `false`.
                let reply = Reply::new(params).set_continues(Some(false));
                #[cfg(feature = "std")]
                conn.send_reply(&reply, reply_fds).await?;
                #[cfg(not(feature = "std"))]
                conn.send_reply(&reply).await?;
            }
            #[cfg(feature = "std")]
            MethodReply::Error(err) => conn.send_error(&err, reply_fds).await?,
            #[cfg(not(feature = "std"))]
            MethodReply::Error(err) => conn.send_error(&err).await?,
            MethodReply::Multi(s) => {
                trace!("Client {} now turning into a reply stream", conn.id());
                stream = Some(s)
            }
        }

        Ok(stream)
    }
}
217
218/// Method reply stream and connection pair.
219#[derive(Debug)]
220struct ReplyStream<St, Sock: Socket> {
221    stream: St,
222    conn: Connection<Sock>,
223}
224
225impl<St, Sock> ReplyStream<St, Sock>
226where
227    Sock: Socket,
228{
229    fn new(stream: St, conn: Connection<Sock>) -> Self {
230        Self { stream, conn }
231    }
232}