// sim_lib_server/transport/socket.rs

use std::{
    fs,
    io::ErrorKind,
    net::{Shutdown, TcpListener, TcpStream},
    path::Path,
    sync::Arc,
    thread,
    time::{Duration, Instant},
};

#[cfg(unix)]
use std::os::unix::{
    fs::FileTypeExt,
    net::{UnixListener, UnixStream},
};

use sim_kernel::{Cx, Error, Result, Symbol};

use crate::{EvalSite, FrameKind, ServerAddress, ServerFrame, ServerRuntime};

use super::{
    ConnectionTransport, SERVER_CONNECTION_IO_TIMEOUT_MS, ServerTransport, answer_or_negotiate,
    error_frame_from_error, io_to_host, is_timeout, read_frame_from,
    update_negotiated_codec_from_reply, write_frame_to,
};
25
26/// TCP listener transport for server-frame connections.
27pub struct TcpServerTransport {
28    address: ServerAddress,
29    listener: TcpListener,
30}
31
32impl TcpServerTransport {
33    /// Binds a TCP listener to `address`.
34    pub fn bind(address: ServerAddress) -> Result<Self> {
35        let ServerAddress::Tcp { host, port } = &address else {
36            return Err(Error::Eval(
37                "tcp transport requires a tcp address".to_owned(),
38            ));
39        };
40        let listener = TcpListener::bind((host.as_str(), *port)).map_err(io_to_host)?;
41        listener.set_nonblocking(true).map_err(io_to_host)?;
42        let local_addr = listener.local_addr().map_err(io_to_host)?;
43        Ok(Self {
44            address: ServerAddress::Tcp {
45                host: host.clone(),
46                port: local_addr.port(),
47            },
48            listener,
49        })
50    }
51
52    #[cfg_attr(not(test), allow(dead_code))]
53    /// Returns the bound local port.
54    pub fn local_port(&self) -> Result<u16> {
55        Ok(self.listener.local_addr().map_err(io_to_host)?.port())
56    }
57}
58
59impl ServerTransport for TcpServerTransport {
60    fn address(&self) -> &ServerAddress {
61        &self.address
62    }
63
64    fn accept(&self, cx: &mut Cx) -> Result<Box<dyn ConnectionTransport>> {
65        loop {
66            if let Some(connection) = self.accept_timeout(cx, Duration::from_millis(25))? {
67                return Ok(connection);
68            }
69        }
70    }
71
72    fn shutdown(&self, _cx: &mut Cx) -> Result<()> {
73        Ok(())
74    }
75
76    fn accept_timeout(
77        &self,
78        _cx: &mut Cx,
79        _timeout: Duration,
80    ) -> Result<Option<Box<dyn ConnectionTransport>>> {
81        match self.listener.accept() {
82            Ok((stream, _peer)) => {
83                stream.set_nodelay(true).map_err(io_to_host)?;
84                Ok(Some(Box::new(TcpConnectionTransport::server_side(stream))))
85            }
86            Err(error) if error.kind() == ErrorKind::WouldBlock => Ok(None),
87            Err(error) => Err(io_to_host(error)),
88        }
89    }
90}
91
/// One TCP connection that exchanges server frames, created either by
/// `connect` or from a stream accepted by the listener.
pub struct TcpConnectionTransport {
    /// Underlying socket that frames are read from and written to.
    stream: TcpStream,
}
95
96impl TcpConnectionTransport {
97    pub fn connect(address: &ServerAddress) -> Result<Self> {
98        let ServerAddress::Tcp { host, port } = address else {
99            return Err(Error::Eval("tcp connect requires a tcp address".to_owned()));
100        };
101        let stream = TcpStream::connect((host.as_str(), *port)).map_err(io_to_host)?;
102        stream.set_nodelay(true).map_err(io_to_host)?;
103        Ok(Self { stream })
104    }
105
106    fn server_side(stream: TcpStream) -> Self {
107        Self { stream }
108    }
109
110    fn serve(&mut self, runtime: &Arc<ServerRuntime>, site: &Arc<dyn EvalSite>) -> Result<()> {
111        let session_id = runtime.open_session(
112            Symbol::qualified("codec", "binary"),
113            runtime.session_isolation().clone(),
114        )?;
115        let mut inflight = 0usize;
116        loop {
117            if runtime.is_stopping() {
118                let _ = runtime.close_session(session_id);
119                return Ok(());
120            }
121
122            let frame = match self.recv_frame_for_serve() {
123                Ok(Some(frame)) => frame,
124                Ok(None) => continue,
125                Err(error) => {
126                    let _ = runtime.close_session(session_id);
127                    return Err(error);
128                }
129            };
130            let Some(frame) = frame else {
131                let _ = runtime.close_session(session_id);
132                return Ok(());
133            };
134            runtime.note_message_received();
135            if runtime.is_stopping() {
136                let _ = runtime.close_session(session_id);
137                return Ok(());
138            }
139            if matches!(frame.kind, FrameKind::Request | FrameKind::Notify)
140                && inflight >= runtime.max_inflight()
141            {
142                let reply = runtime.with_cx(|cx| {
143                    error_frame_from_error(
144                        cx,
145                        &frame,
146                        &Error::Eval(format!(
147                            "connection max-inflight {} exceeded",
148                            runtime.max_inflight()
149                        )),
150                    )
151                })?;
152                write_frame_to(&mut self.stream, &reply)?;
153                runtime.note_message_sent();
154                continue;
155            }
156            if matches!(frame.kind, FrameKind::Request | FrameKind::Notify) {
157                inflight = inflight.saturating_add(1);
158            }
159            let reply = match runtime.with_cx(|cx| answer_or_negotiate(cx, site, frame.clone())) {
160                Ok(reply) => {
161                    update_negotiated_codec_from_reply(runtime, session_id, &frame, &reply)?;
162                    reply
163                }
164                Err(error) => runtime.with_cx(|cx| error_frame_from_error(cx, &frame, &error))?,
165            };
166            if runtime.is_stopping() {
167                let _ = runtime.close_session(session_id);
168                return Ok(());
169            }
170            write_frame_to(&mut self.stream, &reply)?;
171            runtime.note_message_sent();
172            if matches!(frame.kind, FrameKind::Request | FrameKind::Notify) {
173                inflight = inflight.saturating_sub(1);
174            }
175        }
176    }
177
178    fn recv_frame_for_serve(&mut self) -> Result<Option<Option<ServerFrame>>> {
179        self.stream
180            .set_read_timeout(Some(Duration::from_millis(SERVER_CONNECTION_IO_TIMEOUT_MS)))
181            .map_err(io_to_host)?;
182        match read_frame_from(&mut self.stream) {
183            Ok(frame) => Ok(Some(frame)),
184            Err(error) if is_timeout(&error) => Ok(None),
185            Err(error) => Err(error),
186        }
187    }
188}
189
190impl ConnectionTransport for TcpConnectionTransport {
191    fn send_frame(&mut self, _cx: &mut Cx, frame: ServerFrame) -> Result<()> {
192        write_frame_to(&mut self.stream, &frame)
193    }
194
195    fn recv_frame(
196        &mut self,
197        _cx: &mut Cx,
198        timeout: Option<Duration>,
199    ) -> Result<Option<ServerFrame>> {
200        self.stream.set_read_timeout(timeout).map_err(io_to_host)?;
201        match read_frame_from(&mut self.stream) {
202            Ok(frame) => Ok(frame),
203            Err(error) if is_timeout(&error) => Ok(None),
204            Err(error) => Err(error),
205        }
206    }
207
208    fn close(&mut self, _cx: &mut Cx) -> Result<()> {
209        let _ = self.stream.shutdown(Shutdown::Both);
210        Ok(())
211    }
212
213    fn as_any(&self) -> &dyn std::any::Any {
214        self
215    }
216
217    fn serve_connection(
218        &mut self,
219        runtime: &Arc<ServerRuntime>,
220        site: &Arc<dyn EvalSite>,
221    ) -> Result<()> {
222        self.serve(runtime, site)
223    }
224}
225
226#[cfg(unix)]
227pub struct UnixServerTransport {
228    address: ServerAddress,
229    listener: UnixListener,
230}
231
232#[cfg(unix)]
233impl UnixServerTransport {
234    pub fn bind(address: ServerAddress) -> Result<Self> {
235        let ServerAddress::Unix { path } = &address else {
236            return Err(Error::Eval(
237                "unix transport requires a unix address".to_owned(),
238            ));
239        };
240        remove_stale_unix_socket(path)?;
241        let listener = UnixListener::bind(path).map_err(io_to_host)?;
242        listener.set_nonblocking(true).map_err(io_to_host)?;
243        Ok(Self { address, listener })
244    }
245}
246
247#[cfg(unix)]
248impl ServerTransport for UnixServerTransport {
249    fn address(&self) -> &ServerAddress {
250        &self.address
251    }
252
253    fn accept(&self, cx: &mut Cx) -> Result<Box<dyn ConnectionTransport>> {
254        loop {
255            if let Some(connection) = self.accept_timeout(cx, Duration::from_millis(25))? {
256                return Ok(connection);
257            }
258        }
259    }
260
261    fn shutdown(&self, _cx: &mut Cx) -> Result<()> {
262        let ServerAddress::Unix { path } = &self.address else {
263            return Ok(());
264        };
265        remove_bound_unix_socket(path)
266    }
267
268    fn accept_timeout(
269        &self,
270        _cx: &mut Cx,
271        _timeout: Duration,
272    ) -> Result<Option<Box<dyn ConnectionTransport>>> {
273        match self.listener.accept() {
274            Ok((stream, _peer)) => Ok(Some(Box::new(UnixConnectionTransport::server_side(stream)))),
275            Err(error) if error.kind() == ErrorKind::WouldBlock => Ok(None),
276            Err(error) => Err(io_to_host(error)),
277        }
278    }
279}
280
/// One Unix-domain-socket connection that exchanges server frames, created
/// either by `connect` or from a stream accepted by the listener.
#[cfg(unix)]
pub struct UnixConnectionTransport {
    /// Underlying socket that frames are read from and written to.
    stream: UnixStream,
}
285
286#[cfg(unix)]
287impl UnixConnectionTransport {
288    pub fn connect(address: &ServerAddress) -> Result<Self> {
289        let ServerAddress::Unix { path } = address else {
290            return Err(Error::Eval(
291                "unix connect requires a unix address".to_owned(),
292            ));
293        };
294        let stream = UnixStream::connect(path).map_err(io_to_host)?;
295        Ok(Self { stream })
296    }
297
298    fn server_side(stream: UnixStream) -> Self {
299        Self { stream }
300    }
301
302    fn serve(&mut self, runtime: &Arc<ServerRuntime>, site: &Arc<dyn EvalSite>) -> Result<()> {
303        let session_id = runtime.open_session(
304            Symbol::qualified("codec", "binary"),
305            runtime.session_isolation().clone(),
306        )?;
307        let mut inflight = 0usize;
308        loop {
309            if runtime.is_stopping() {
310                let _ = runtime.close_session(session_id);
311                return Ok(());
312            }
313
314            let frame = match self.recv_frame_for_serve() {
315                Ok(Some(frame)) => frame,
316                Ok(None) => continue,
317                Err(error) => {
318                    let _ = runtime.close_session(session_id);
319                    return Err(error);
320                }
321            };
322            let Some(frame) = frame else {
323                let _ = runtime.close_session(session_id);
324                return Ok(());
325            };
326            runtime.note_message_received();
327            if runtime.is_stopping() {
328                let _ = runtime.close_session(session_id);
329                return Ok(());
330            }
331            if matches!(frame.kind, FrameKind::Request | FrameKind::Notify)
332                && inflight >= runtime.max_inflight()
333            {
334                let reply = runtime.with_cx(|cx| {
335                    error_frame_from_error(
336                        cx,
337                        &frame,
338                        &Error::Eval(format!(
339                            "connection max-inflight {} exceeded",
340                            runtime.max_inflight()
341                        )),
342                    )
343                })?;
344                write_frame_to(&mut self.stream, &reply)?;
345                runtime.note_message_sent();
346                continue;
347            }
348            if matches!(frame.kind, FrameKind::Request | FrameKind::Notify) {
349                inflight = inflight.saturating_add(1);
350            }
351            let reply = match runtime.with_cx(|cx| answer_or_negotiate(cx, site, frame.clone())) {
352                Ok(reply) => {
353                    update_negotiated_codec_from_reply(runtime, session_id, &frame, &reply)?;
354                    reply
355                }
356                Err(error) => runtime.with_cx(|cx| error_frame_from_error(cx, &frame, &error))?,
357            };
358            if runtime.is_stopping() {
359                let _ = runtime.close_session(session_id);
360                return Ok(());
361            }
362            write_frame_to(&mut self.stream, &reply)?;
363            runtime.note_message_sent();
364            if matches!(frame.kind, FrameKind::Request | FrameKind::Notify) {
365                inflight = inflight.saturating_sub(1);
366            }
367        }
368    }
369
370    fn recv_frame_for_serve(&mut self) -> Result<Option<Option<ServerFrame>>> {
371        self.stream
372            .set_read_timeout(Some(Duration::from_millis(SERVER_CONNECTION_IO_TIMEOUT_MS)))
373            .map_err(io_to_host)?;
374        match read_frame_from(&mut self.stream) {
375            Ok(frame) => Ok(Some(frame)),
376            Err(error) if is_timeout(&error) => Ok(None),
377            Err(error) => Err(error),
378        }
379    }
380}
381
382#[cfg(unix)]
383impl ConnectionTransport for UnixConnectionTransport {
384    fn send_frame(&mut self, _cx: &mut Cx, frame: ServerFrame) -> Result<()> {
385        write_frame_to(&mut self.stream, &frame)
386    }
387
388    fn recv_frame(
389        &mut self,
390        _cx: &mut Cx,
391        timeout: Option<Duration>,
392    ) -> Result<Option<ServerFrame>> {
393        self.stream.set_read_timeout(timeout).map_err(io_to_host)?;
394        match read_frame_from(&mut self.stream) {
395            Ok(frame) => Ok(frame),
396            Err(error) if is_timeout(&error) => Ok(None),
397            Err(error) => Err(error),
398        }
399    }
400
401    fn close(&mut self, _cx: &mut Cx) -> Result<()> {
402        Ok(())
403    }
404
405    fn as_any(&self) -> &dyn std::any::Any {
406        self
407    }
408
409    fn serve_connection(
410        &mut self,
411        runtime: &Arc<ServerRuntime>,
412        site: &Arc<dyn EvalSite>,
413    ) -> Result<()> {
414        self.serve(runtime, site)
415    }
416}
417
418#[cfg(unix)]
419fn remove_stale_unix_socket(path: &Path) -> Result<()> {
420    match fs::symlink_metadata(path) {
421        Ok(metadata) if metadata.file_type().is_socket() => {
422            fs::remove_file(path).map_err(io_to_host)?;
423            Ok(())
424        }
425        Ok(_) => Ok(()),
426        Err(error) if error.kind() == ErrorKind::NotFound => Ok(()),
427        Err(error) => Err(io_to_host(error)),
428    }
429}
430
431#[cfg(unix)]
432fn remove_bound_unix_socket(path: &Path) -> Result<()> {
433    match fs::symlink_metadata(path) {
434        Ok(metadata) if metadata.file_type().is_socket() => {
435            fs::remove_file(path).map_err(io_to_host)?;
436            Ok(())
437        }
438        Ok(_) => Ok(()),
439        Err(error) if error.kind() == ErrorKind::NotFound => Ok(()),
440        Err(error) => Err(io_to_host(error)),
441    }
442}