zlink_core/server/
mod.rs

1pub(crate) mod listener;
2mod select_all;
3pub mod service;
4
5use alloc::vec::Vec;
6use futures_util::{FutureExt, StreamExt};
7use select_all::SelectAll;
8use service::MethodReply;
9
10use crate::{connection::Socket, Call, Connection, Reply};
11
/// A server.
///
/// The server listens for incoming connections and handles method calls using a service.
///
/// Create one with [`Server::new`] and drive it with [`Server::run`].
#[derive(Debug)]
pub struct Server<Listener, Service> {
    /// Source of incoming connections.
    ///
    /// Kept in an `Option` so that `run` can move it out of `self` with `take()`; `new` always
    /// stores `Some`.
    listener: Option<Listener>,
    /// The service that method calls received on the connections are handed to.
    service: Service,
}
20
21impl<Listener, Service> Server<Listener, Service>
22where
23    Listener: listener::Listener,
24    Service: service::Service,
25{
26    /// Create a new server that serves `service` to incoming connections from `listener`.
27    pub fn new(listener: Listener, service: Service) -> Self {
28        Self {
29            listener: Some(listener),
30            service,
31        }
32    }
33
    /// Run the server.
    ///
    /// Consumes the server and loops forever, multiplexing three kinds of work in every iteration:
    /// accepting a new connection, reading a method call from an existing connection and handing
    /// it to the service, and forwarding the next reply of an active reply stream to its client.
    ///
    /// Failures on an individual connection (reading a call, writing a reply) are logged and that
    /// connection is dropped; they do not stop the server.
    ///
    /// # Errors
    ///
    /// The loop has no exit of its own, so this method only returns when accepting a connection
    /// from the listener fails, in which case that error is returned.
    ///
    /// # Caveats
    ///
    /// Due to [a bug in the rust compiler][abrc], the future returned by this method cannot be
    /// treated as `Send`, even if all the specific types involved are `Send`. A major consequence
    /// of this fact unfortunately, is that it cannot be spawned in a task of a multi-threaded
    /// runtime. For example, you cannot currently do `tokio::spawn(server.run())`.
    ///
    /// Fortunately, there are easy workarounds for this. You can either:
    ///
    /// * Use a thread-local runtime (for example [`tokio::runtime::LocalRuntime`] or
    ///   [`tokio::task::LocalSet`]) to run the server in a local task, perhaps in a separate
    ///   thread.
    /// * Use some common API to run multiple futures at once, such as [`futures::select!`] or
    ///   [`tokio::select!`].
    ///
    /// Most importantly, this is most likely a temporary issue and will be fixed in the future. 😊
    ///
    /// [abrc]: https://github.com/rust-lang/rust/issues/100013
    /// [`tokio::runtime::LocalRuntime`]: https://docs.rs/tokio/latest/tokio/runtime/struct.LocalRuntime.html
    /// [`tokio::task::LocalSet`]: https://docs.rs/tokio/latest/tokio/task/struct.LocalSet.html
    /// [`futures::select!`]: https://docs.rs/futures/latest/futures/macro.select.html
    /// [`tokio::select!`]: https://docs.rs/tokio/latest/tokio/macro.select.html
    pub async fn run(mut self) -> crate::Result<()> {
        // Move the listener out of `self` into its own local (presumably so that the pending
        // `accept` future does not conflict with the `&mut self` borrow taken by `handle_call`
        // below). `new` always stores `Some` and `run` consumes `self`, so the `unwrap` is not
        // expected to fail.
        let mut listener = self.listener.take().unwrap();
        // Connections that are waiting for their next method call.
        let mut connections = Vec::new();
        // Connections that are currently bound to a reply stream, paired with that stream.
        let mut reply_streams = Vec::<ReplyStream<Service::ReplyStream, Listener::Socket>>::new();
        // Futures yielding the next item of each reply stream; rebuilt on every iteration.
        let mut reply_stream_futures = Vec::new();
        // Vec for futures from `Connection::receive_call`. Reused across iterations to avoid
        // per-iteration allocations.
        let mut read_futures = Vec::new();
        // Index of the future that completed last in the respective `SelectAll`. The next
        // `SelectAll` is created with `winner + 1` as its start index, presumably so that polling
        // rotates instead of always favouring the lowest indexes (see `select_all` to confirm).
        let mut last_reply_stream_winner = None;
        let mut last_method_call_winner = None;

        loop {
            // We re-populate the `reply_stream_futures` in each iteration so we must clear it
            // first.
            reply_stream_futures.clear();
            {
                // SAFETY: Rust has no way to know that we don't re-use the mutable references in
                // each iteration (since we clear the `reply_stream_futures` vector) so we need to
                // go through a pointer to work around this.
                //
                // NOTE(review): the futures created here stay alive in `reply_stream_futures`
                // while the select arms below index into and `swap_remove` from `reply_streams`
                // directly. This relies on those futures not being polled again once the select
                // has resolved, only dropped by the next `clear()` — confirm.
                let reply_streams: &mut Vec<ReplyStream<Service::ReplyStream, Listener::Socket>> =
                    unsafe { &mut *(&mut reply_streams as *mut Vec<_>) };
                reply_stream_futures.extend(reply_streams.iter_mut().map(|s| s.stream.next()));
            }
            let start_index = last_reply_stream_winner.map(|idx| idx + 1);
            let mut reply_stream_select_all = SelectAll::new(start_index);
            for future in reply_stream_futures.iter_mut() {
                reply_stream_select_all.push(future);
            }

            // Prepare futures for reading method calls from connections.
            read_futures.clear();
            {
                // SAFETY: Same as above - mutable references are not reused across iterations.
                //
                // NOTE(review): same caveat as for `reply_streams`: the select arms below push
                // to, index into and `swap_remove` from `connections` while `read_futures` still
                // holds futures created from it — confirm they are never polled afterwards.
                let connections: &mut Vec<Connection<Listener::Socket>> =
                    unsafe { &mut *(&mut connections as *mut Vec<_>) };
                read_futures.extend(connections.iter_mut().map(|c| c.receive_call()));
            }
            let mut read_select_all = SelectAll::new(last_method_call_winner.map(|idx| idx + 1));
            for future in &mut read_futures {
                // SAFETY: Futures in `read_futures` are dropped in place via `clear()` at the
                // start of the next iteration, never moved while pinned.
                unsafe {
                    read_select_all.push_unchecked(future);
                }
            }

            // Exactly one arm runs per loop iteration; the futures of the other arms are
            // rebuilt from scratch on the next iteration.
            futures_util::select_biased! {
                // 1. Accept a new connection.
                // A failed accept is propagated with `?`; this is the only way `run` returns.
                conn = listener.accept().fuse() => {
                    connections.push(conn?);
                }
                // 2. Read method calls from the existing connections and handle them.
                // `idx` is the position in `connections` of the connection that produced `result`.
                (idx, result) = read_select_all.fuse() => {
                        // With `std`, the received value is a pair whose second element (named
                        // `_fds`) is discarded here; without `std` it is the call itself.
                        #[cfg(feature = "std")]
                        let call = result.map(|(call, _fds)| call);
                        #[cfg(not(feature = "std"))]
                        let call = result;
                        last_method_call_winner = Some(idx);

                        // `stream` is set when the service answered with a reply stream.
                        // `remove` stays `true` unless the call was handled without producing a
                        // stream, i.e. the connection only stays in `connections` on plain
                        // success.
                        let mut stream = None;
                        let mut remove = true;
                        match call {
                            Ok(call) => {
                                match self.handle_call(call, &mut connections[idx]).await {
                                    Ok(None) => remove = false,
                                    Ok(Some(s)) => stream = Some(s),
                                    Err(e) => warn!("Error writing to connection: {:?}", e),
                                }
                            }
                            Err(e) => warn!("Error reading from socket: {:?}", e),
                        }

                        // NOTE(review): `remove` is already `true` whenever `stream` is `Some`,
                        // so the first half of this condition looks redundant.
                        if stream.is_some() || remove {
                            // `swap_remove` moves the last connection into `idx`, so the order of
                            // `connections` is not preserved.
                            let conn = connections.swap_remove(idx);

                            // With a stream, the connection moves to `reply_streams`; otherwise
                            // (read or write error) it is dropped at the end of this block.
                            if let Some(stream) = stream {
                                reply_streams.push(ReplyStream::new(stream, conn));
                            }
                        }
                }
                // 3. Read replies from the reply streams and send them off.
                reply = reply_stream_select_all.fuse() => {
                    // `idx` is the position in `reply_streams` of the stream that yielded.
                    let (idx, reply) = reply;
                    last_reply_stream_winner = Some(idx);
                    let id = reply_streams[idx].conn.id();

                    match reply {
                        // The stream produced a reply: forward it to the paired connection.
                        Some(reply) => {
                            #[cfg(feature = "std")]
                            let send_result =
                                reply_streams[idx].conn.send_reply(&reply, alloc::vec![]).await;
                            #[cfg(not(feature = "std"))]
                            let send_result = reply_streams[idx].conn.send_reply(&reply).await;
                            if let Err(e) = send_result {
                                warn!("Error writing to client {}: {:?}", id, e);
                                // Drops both the stream and its connection.
                                reply_streams.swap_remove(idx);
                            }
                        }
                        // The stream is exhausted: put the connection back among those that are
                        // read for method calls.
                        None => {
                            trace!("Stream closed for client {}", id);
                            let stream = reply_streams.swap_remove(idx);
                            connections.push(stream.conn);
                        }
                    }
                }
            }
        }
    }
166
167    async fn handle_call(
168        &mut self,
169        call: Call<Service::MethodCall<'_>>,
170        conn: &mut Connection<Listener::Socket>,
171    ) -> crate::Result<Option<Service::ReplyStream>> {
172        let mut stream = None;
173        match self.service.handle(&call, conn).await {
174            // Don't send replies or errors for oneway calls.
175            MethodReply::Single(_) | MethodReply::Error(_) if call.oneway() => (),
176            MethodReply::Single(params) => {
177                let reply = Reply::new(params).set_continues(Some(false));
178                #[cfg(feature = "std")]
179                conn.send_reply(&reply, alloc::vec![]).await?;
180                #[cfg(not(feature = "std"))]
181                conn.send_reply(&reply).await?;
182            }
183            #[cfg(feature = "std")]
184            MethodReply::Error(err) => conn.send_error(&err, alloc::vec![]).await?,
185            #[cfg(not(feature = "std"))]
186            MethodReply::Error(err) => conn.send_error(&err).await?,
187            MethodReply::Multi(s) => {
188                trace!("Client {} now turning into a reply stream", conn.id());
189                stream = Some(s)
190            }
191        }
192
193        Ok(stream)
194    }
195}
196
/// Method reply stream and connection pair.
///
/// `Server::run` keeps one of these for every connection whose method call was answered with a
/// stream of replies.
#[derive(Debug)]
struct ReplyStream<St, Sock: Socket> {
    /// The stream whose items are sent as replies over `conn`.
    stream: St,
    /// The connection the replies are sent over; `Server::run` puts it back among the
    /// connections it reads method calls from once `stream` ends.
    conn: Connection<Sock>,
}
203
204impl<St, Sock> ReplyStream<St, Sock>
205where
206    Sock: Socket,
207{
208    fn new(stream: St, conn: Connection<Sock>) -> Self {
209        Self { stream, conn }
210    }
211}