Skip to main content

zlink_core/server/
mod.rs

1pub(crate) mod listener;
2mod select_all;
3pub mod service;
4
5mod infallible;
6
7use alloc::vec::Vec;
8use futures_util::{FutureExt, StreamExt};
9use select_all::SelectAll;
10use service::MethodReply;
11
12use crate::{Call, Connection, Reply, connection::Socket};
13
/// A server.
///
/// The server listens for incoming connections and handles method calls using a service.
#[derive(Debug)]
pub struct Server<Listener, Service> {
    /// The listener that new connections are accepted from.
    ///
    /// `new` always stores `Some` here; `run` takes it out so it can own the listener while
    /// still calling `&mut self` methods on the rest of the server.
    listener: Option<Listener>,
    /// The service that method calls received on the connections are dispatched to.
    service: Service,
}
22
23impl<Listener, Service> Server<Listener, Service>
24where
25    Listener: listener::Listener,
26    Service: service::Service<Listener::Socket>,
27{
28    /// Create a new server that serves `service` to incoming connections from `listener`.
29    pub fn new(listener: Listener, service: Service) -> Self {
30        Self {
31            listener: Some(listener),
32            service,
33        }
34    }
35
36    /// Run the server.
37    ///
38    /// # Caveats
39    ///
40    /// Due to [a bug in the rust compiler][abrc], the future returned by this method can not be
41    /// treated as `Send`, even if all the specific types involved are `Send`. A major consequence
42    /// of this fact unfortunately, is that it can not be spawned in a task of a multi-threaded
43    /// runtime. For example, you can not currently do `tokio::spawn(server.run())`.
44    ///
45    /// Fortunately, there are easy workarounds for this. You can either:
46    ///
47    /// * Use a thread-local runtime (for example [`tokio::runtime::LocalRuntime`] or
48    ///   [`tokio::task::LocalSet`]) to run the server in a local task, perhaps in a separate
49    ///   thread.
50    /// * Use some common API to run multiple futures at once, such as [`futures::select!`] or
51    ///   [`tokio::select!`].
52    ///
53    /// Most importantly, this is most likely a temporary issue and will be fixed in the future. 😊
54    ///
55    /// [abrc]: https://github.com/rust-lang/rust/issues/100013
56    /// [`tokio::runtime::LocalRuntime`]: https://docs.rs/tokio/latest/tokio/runtime/struct.LocalRuntime.html
57    /// [`tokio::task::LocalSet`]: https://docs.rs/tokio/latest/tokio/task/struct.LocalSet.html
58    /// [`futures::select!`]: https://docs.rs/futures/latest/futures/macro.select.html
59    /// [`tokio::select!`]: https://docs.rs/tokio/latest/tokio/macro.select.html
60    pub async fn run(mut self) -> crate::Result<()> {
61        let mut listener = self.listener.take().unwrap();
62        let mut connections = Vec::new();
63        let mut reply_streams = Vec::<ReplyStream<Service::ReplyStream, Listener::Socket>>::new();
64        let mut reply_stream_futures = Vec::new();
65        // Vec for futures from `Connection::receive_call`. Reused across iterations to avoid
66        // per-iteration allocations.
67        let mut read_futures = Vec::new();
68        let mut last_reply_stream_winner = None;
69        let mut last_method_call_winner = None;
70
71        loop {
72            // We re-populate the `reply_stream_futures` in each iteration so we must clear it
73            // first.
74            reply_stream_futures.clear();
75            {
76                // SAFETY: Rust has no way to know that we don't re-use the mutable references in
77                // each iteration (since we clear the `reply_stream_futures` vector) so we need to
78                // go through a pointer to work around this.
79                let reply_streams: &mut Vec<ReplyStream<Service::ReplyStream, Listener::Socket>> =
80                    unsafe { &mut *(&mut reply_streams as *mut Vec<_>) };
81                reply_stream_futures.extend(reply_streams.iter_mut().map(|s| s.stream.next()));
82            }
83            let start_index = last_reply_stream_winner.map(|idx| idx + 1);
84            let mut reply_stream_select_all = SelectAll::new(start_index);
85            for future in reply_stream_futures.iter_mut() {
86                reply_stream_select_all.push(future);
87            }
88
89            // Prepare futures for reading method calls from connections.
90            read_futures.clear();
91            {
92                // SAFETY: Same as above - mutable references are not reused across iterations.
93                let connections: &mut Vec<Connection<Listener::Socket>> =
94                    unsafe { &mut *(&mut connections as *mut Vec<_>) };
95                read_futures.extend(connections.iter_mut().map(|c| c.receive_call()));
96            }
97            let mut read_select_all = SelectAll::new(last_method_call_winner.map(|idx| idx + 1));
98            for future in &mut read_futures {
99                // SAFETY: Futures in `read_futures` are dropped in place via `clear()` at the
100                // start of the next iteration, never moved while pinned.
101                unsafe {
102                    read_select_all.push_unchecked(future);
103                }
104            }
105
106            futures_util::select_biased! {
107                // 1. Accept a new connection.
108                conn = listener.accept().fuse() => {
109                    connections.push(conn?);
110                }
111                // 2. Read method calls from the existing connections and handle them.
112                (idx, result) = read_select_all.fuse() => {
113                        #[cfg(feature = "std")]
114                        let (call, fds) = match result {
115                            Ok((call, fds)) => (Ok(call), fds),
116                            Err(e) => (Err(e), alloc::vec![]),
117                        };
118                        #[cfg(not(feature = "std"))]
119                        let call = result;
120                        last_method_call_winner = Some(idx);
121
122                        let mut stream = None;
123                        let mut remove = true;
124                        match call {
125                            Ok(call) => {
126                                #[cfg(feature = "std")]
127                                let result =
128                                    self.handle_call(call, &mut connections[idx], fds).await;
129                                #[cfg(not(feature = "std"))]
130                                let result =
131                                    self.handle_call(call, &mut connections[idx]).await;
132                                match result {
133                                    Ok(None) => remove = false,
134                                    Ok(Some(s)) => stream = Some(s),
135                                    Err(e) => warn!("Error writing to connection: {:?}", e),
136                                }
137                            }
138                            Err(e) => debug!("Error reading from socket: {:?}", e),
139                        }
140
141                        if stream.is_some() || remove {
142                            let conn = connections.swap_remove(idx);
143
144                            if let Some(stream) = stream {
145                                reply_streams.push(ReplyStream::new(stream, conn));
146                            }
147                        }
148                }
149                // 3. Read replies from the reply streams and send them off.
150                reply = reply_stream_select_all.fuse() => {
151                    let (idx, item) = reply;
152                    last_reply_stream_winner = Some(idx);
153                    let id = reply_streams[idx].conn.id();
154
155                    match item {
156                        Some(item) => {
157                            #[cfg(feature = "std")]
158                            let (item, fds) = item;
159                            #[cfg(feature = "std")]
160                            let send_result = match item {
161                                Ok(reply) => reply_streams[idx]
162                                    .conn
163                                    .send_reply(&reply, fds)
164                                    .await,
165                                Err(error) => reply_streams[idx]
166                                    .conn
167                                    .send_error(&error, fds)
168                                    .await,
169                            };
170                            #[cfg(not(feature = "std"))]
171                            let send_result = match item {
172                                Ok(reply) => reply_streams[idx].conn.send_reply(&reply).await,
173                                Err(error) => reply_streams[idx].conn.send_error(&error).await,
174                            };
175                            if let Err(e) = send_result {
176                                warn!("Error writing to client {}: {:?}", id, e);
177                                reply_streams.swap_remove(idx);
178                            }
179                        }
180                        None => {
181                            trace!("Stream closed for client {}", id);
182                            let stream = reply_streams.swap_remove(idx);
183                            connections.push(stream.conn);
184                        }
185                    }
186                }
187            }
188        }
189    }
190
191    async fn handle_call(
192        &mut self,
193        call: Call<Service::MethodCall<'_>>,
194        conn: &mut Connection<Listener::Socket>,
195        #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
196    ) -> crate::Result<Option<Service::ReplyStream>> {
197        let mut stream = None;
198
199        #[cfg(feature = "std")]
200        let (reply, reply_fds) = self.service.handle(&call, conn, fds).await;
201        #[cfg(not(feature = "std"))]
202        let reply = self.service.handle(&call, conn).await;
203
204        match reply {
205            // Don't send replies or errors for oneway calls.
206            MethodReply::Single(_) | MethodReply::Error(_) if call.oneway() => (),
207            MethodReply::Single(params) => {
208                let reply = Reply::new(params).set_continues(Some(false));
209                #[cfg(feature = "std")]
210                conn.send_reply(&reply, reply_fds).await?;
211                #[cfg(not(feature = "std"))]
212                conn.send_reply(&reply).await?;
213            }
214            #[cfg(feature = "std")]
215            MethodReply::Error(err) => conn.send_error(&err, reply_fds).await?,
216            #[cfg(not(feature = "std"))]
217            MethodReply::Error(err) => conn.send_error(&err).await?,
218            MethodReply::Multi(s) => {
219                trace!("Client {} now turning into a reply stream", conn.id());
220                stream = Some(s)
221            }
222        }
223
224        Ok(stream)
225    }
226}
227
228/// Method reply stream and connection pair.
229#[derive(Debug)]
230struct ReplyStream<St, Sock: Socket> {
231    stream: St,
232    conn: Connection<Sock>,
233}
234
235impl<St, Sock> ReplyStream<St, Sock>
236where
237    Sock: Socket,
238{
239    fn new(stream: St, conn: Connection<Sock>) -> Self {
240        Self { stream, conn }
241    }
242}