sim_lib_server/transport/
socket.rs1use std::{
2 fs,
3 io::ErrorKind,
4 net::{Shutdown, TcpListener, TcpStream},
5 path::Path,
6 sync::Arc,
7 time::Duration,
8};
9
10#[cfg(unix)]
11use std::os::unix::{
12 fs::FileTypeExt,
13 net::{UnixListener, UnixStream},
14};
15
16use sim_kernel::{Cx, Error, Result, Symbol};
17
18use crate::{EvalSite, FrameKind, ServerAddress, ServerFrame, ServerRuntime};
19
20use super::{
21 ConnectionTransport, SERVER_CONNECTION_IO_TIMEOUT_MS, ServerTransport, answer_or_negotiate,
22 error_frame_from_error, io_to_host, is_timeout, read_frame_from,
23 update_negotiated_codec_from_reply, write_frame_to,
24};
25
26pub struct TcpServerTransport {
28 address: ServerAddress,
29 listener: TcpListener,
30}
31
32impl TcpServerTransport {
33 pub fn bind(address: ServerAddress) -> Result<Self> {
35 let ServerAddress::Tcp { host, port } = &address else {
36 return Err(Error::Eval(
37 "tcp transport requires a tcp address".to_owned(),
38 ));
39 };
40 let listener = TcpListener::bind((host.as_str(), *port)).map_err(io_to_host)?;
41 listener.set_nonblocking(true).map_err(io_to_host)?;
42 let local_addr = listener.local_addr().map_err(io_to_host)?;
43 Ok(Self {
44 address: ServerAddress::Tcp {
45 host: host.clone(),
46 port: local_addr.port(),
47 },
48 listener,
49 })
50 }
51
52 #[cfg_attr(not(test), allow(dead_code))]
53 pub fn local_port(&self) -> Result<u16> {
55 Ok(self.listener.local_addr().map_err(io_to_host)?.port())
56 }
57}
58
59impl ServerTransport for TcpServerTransport {
60 fn address(&self) -> &ServerAddress {
61 &self.address
62 }
63
64 fn accept(&self, cx: &mut Cx) -> Result<Box<dyn ConnectionTransport>> {
65 loop {
66 if let Some(connection) = self.accept_timeout(cx, Duration::from_millis(25))? {
67 return Ok(connection);
68 }
69 }
70 }
71
72 fn shutdown(&self, _cx: &mut Cx) -> Result<()> {
73 Ok(())
74 }
75
76 fn accept_timeout(
77 &self,
78 _cx: &mut Cx,
79 _timeout: Duration,
80 ) -> Result<Option<Box<dyn ConnectionTransport>>> {
81 match self.listener.accept() {
82 Ok((stream, _peer)) => {
83 stream.set_nodelay(true).map_err(io_to_host)?;
84 Ok(Some(Box::new(TcpConnectionTransport::server_side(stream))))
85 }
86 Err(error) if error.kind() == ErrorKind::WouldBlock => Ok(None),
87 Err(error) => Err(io_to_host(error)),
88 }
89 }
90}
91
92pub struct TcpConnectionTransport {
93 stream: TcpStream,
94}
95
96impl TcpConnectionTransport {
97 pub fn connect(address: &ServerAddress) -> Result<Self> {
98 let ServerAddress::Tcp { host, port } = address else {
99 return Err(Error::Eval("tcp connect requires a tcp address".to_owned()));
100 };
101 let stream = TcpStream::connect((host.as_str(), *port)).map_err(io_to_host)?;
102 stream.set_nodelay(true).map_err(io_to_host)?;
103 Ok(Self { stream })
104 }
105
106 fn server_side(stream: TcpStream) -> Self {
107 Self { stream }
108 }
109
110 fn serve(&mut self, runtime: &Arc<ServerRuntime>, site: &Arc<dyn EvalSite>) -> Result<()> {
111 let session_id = runtime.open_session(
112 Symbol::qualified("codec", "binary"),
113 runtime.session_isolation().clone(),
114 )?;
115 let mut inflight = 0usize;
116 loop {
117 if runtime.is_stopping() {
118 let _ = runtime.close_session(session_id);
119 return Ok(());
120 }
121
122 let frame = match self.recv_frame_for_serve() {
123 Ok(Some(frame)) => frame,
124 Ok(None) => continue,
125 Err(error) => {
126 let _ = runtime.close_session(session_id);
127 return Err(error);
128 }
129 };
130 let Some(frame) = frame else {
131 let _ = runtime.close_session(session_id);
132 return Ok(());
133 };
134 runtime.note_message_received();
135 if runtime.is_stopping() {
136 let _ = runtime.close_session(session_id);
137 return Ok(());
138 }
139 if matches!(frame.kind, FrameKind::Request | FrameKind::Notify)
140 && inflight >= runtime.max_inflight()
141 {
142 let reply = runtime.with_cx(|cx| {
143 error_frame_from_error(
144 cx,
145 &frame,
146 &Error::Eval(format!(
147 "connection max-inflight {} exceeded",
148 runtime.max_inflight()
149 )),
150 )
151 })?;
152 write_frame_to(&mut self.stream, &reply)?;
153 runtime.note_message_sent();
154 continue;
155 }
156 if matches!(frame.kind, FrameKind::Request | FrameKind::Notify) {
157 inflight = inflight.saturating_add(1);
158 }
159 let reply = match runtime.with_cx(|cx| answer_or_negotiate(cx, site, frame.clone())) {
160 Ok(reply) => {
161 update_negotiated_codec_from_reply(runtime, session_id, &frame, &reply)?;
162 reply
163 }
164 Err(error) => runtime.with_cx(|cx| error_frame_from_error(cx, &frame, &error))?,
165 };
166 if runtime.is_stopping() {
167 let _ = runtime.close_session(session_id);
168 return Ok(());
169 }
170 write_frame_to(&mut self.stream, &reply)?;
171 runtime.note_message_sent();
172 if matches!(frame.kind, FrameKind::Request | FrameKind::Notify) {
173 inflight = inflight.saturating_sub(1);
174 }
175 }
176 }
177
178 fn recv_frame_for_serve(&mut self) -> Result<Option<Option<ServerFrame>>> {
179 self.stream
180 .set_read_timeout(Some(Duration::from_millis(SERVER_CONNECTION_IO_TIMEOUT_MS)))
181 .map_err(io_to_host)?;
182 match read_frame_from(&mut self.stream) {
183 Ok(frame) => Ok(Some(frame)),
184 Err(error) if is_timeout(&error) => Ok(None),
185 Err(error) => Err(error),
186 }
187 }
188}
189
190impl ConnectionTransport for TcpConnectionTransport {
191 fn send_frame(&mut self, _cx: &mut Cx, frame: ServerFrame) -> Result<()> {
192 write_frame_to(&mut self.stream, &frame)
193 }
194
195 fn recv_frame(
196 &mut self,
197 _cx: &mut Cx,
198 timeout: Option<Duration>,
199 ) -> Result<Option<ServerFrame>> {
200 self.stream.set_read_timeout(timeout).map_err(io_to_host)?;
201 match read_frame_from(&mut self.stream) {
202 Ok(frame) => Ok(frame),
203 Err(error) if is_timeout(&error) => Ok(None),
204 Err(error) => Err(error),
205 }
206 }
207
208 fn close(&mut self, _cx: &mut Cx) -> Result<()> {
209 let _ = self.stream.shutdown(Shutdown::Both);
210 Ok(())
211 }
212
213 fn as_any(&self) -> &dyn std::any::Any {
214 self
215 }
216
217 fn serve_connection(
218 &mut self,
219 runtime: &Arc<ServerRuntime>,
220 site: &Arc<dyn EvalSite>,
221 ) -> Result<()> {
222 self.serve(runtime, site)
223 }
224}
225
226#[cfg(unix)]
227pub struct UnixServerTransport {
228 address: ServerAddress,
229 listener: UnixListener,
230}
231
232#[cfg(unix)]
233impl UnixServerTransport {
234 pub fn bind(address: ServerAddress) -> Result<Self> {
235 let ServerAddress::Unix { path } = &address else {
236 return Err(Error::Eval(
237 "unix transport requires a unix address".to_owned(),
238 ));
239 };
240 remove_stale_unix_socket(path)?;
241 let listener = UnixListener::bind(path).map_err(io_to_host)?;
242 listener.set_nonblocking(true).map_err(io_to_host)?;
243 Ok(Self { address, listener })
244 }
245}
246
247#[cfg(unix)]
248impl ServerTransport for UnixServerTransport {
249 fn address(&self) -> &ServerAddress {
250 &self.address
251 }
252
253 fn accept(&self, cx: &mut Cx) -> Result<Box<dyn ConnectionTransport>> {
254 loop {
255 if let Some(connection) = self.accept_timeout(cx, Duration::from_millis(25))? {
256 return Ok(connection);
257 }
258 }
259 }
260
261 fn shutdown(&self, _cx: &mut Cx) -> Result<()> {
262 let ServerAddress::Unix { path } = &self.address else {
263 return Ok(());
264 };
265 remove_bound_unix_socket(path)
266 }
267
268 fn accept_timeout(
269 &self,
270 _cx: &mut Cx,
271 _timeout: Duration,
272 ) -> Result<Option<Box<dyn ConnectionTransport>>> {
273 match self.listener.accept() {
274 Ok((stream, _peer)) => Ok(Some(Box::new(UnixConnectionTransport::server_side(stream)))),
275 Err(error) if error.kind() == ErrorKind::WouldBlock => Ok(None),
276 Err(error) => Err(io_to_host(error)),
277 }
278 }
279}
280
281#[cfg(unix)]
282pub struct UnixConnectionTransport {
283 stream: UnixStream,
284}
285
286#[cfg(unix)]
287impl UnixConnectionTransport {
288 pub fn connect(address: &ServerAddress) -> Result<Self> {
289 let ServerAddress::Unix { path } = address else {
290 return Err(Error::Eval(
291 "unix connect requires a unix address".to_owned(),
292 ));
293 };
294 let stream = UnixStream::connect(path).map_err(io_to_host)?;
295 Ok(Self { stream })
296 }
297
298 fn server_side(stream: UnixStream) -> Self {
299 Self { stream }
300 }
301
302 fn serve(&mut self, runtime: &Arc<ServerRuntime>, site: &Arc<dyn EvalSite>) -> Result<()> {
303 let session_id = runtime.open_session(
304 Symbol::qualified("codec", "binary"),
305 runtime.session_isolation().clone(),
306 )?;
307 let mut inflight = 0usize;
308 loop {
309 if runtime.is_stopping() {
310 let _ = runtime.close_session(session_id);
311 return Ok(());
312 }
313
314 let frame = match self.recv_frame_for_serve() {
315 Ok(Some(frame)) => frame,
316 Ok(None) => continue,
317 Err(error) => {
318 let _ = runtime.close_session(session_id);
319 return Err(error);
320 }
321 };
322 let Some(frame) = frame else {
323 let _ = runtime.close_session(session_id);
324 return Ok(());
325 };
326 runtime.note_message_received();
327 if runtime.is_stopping() {
328 let _ = runtime.close_session(session_id);
329 return Ok(());
330 }
331 if matches!(frame.kind, FrameKind::Request | FrameKind::Notify)
332 && inflight >= runtime.max_inflight()
333 {
334 let reply = runtime.with_cx(|cx| {
335 error_frame_from_error(
336 cx,
337 &frame,
338 &Error::Eval(format!(
339 "connection max-inflight {} exceeded",
340 runtime.max_inflight()
341 )),
342 )
343 })?;
344 write_frame_to(&mut self.stream, &reply)?;
345 runtime.note_message_sent();
346 continue;
347 }
348 if matches!(frame.kind, FrameKind::Request | FrameKind::Notify) {
349 inflight = inflight.saturating_add(1);
350 }
351 let reply = match runtime.with_cx(|cx| answer_or_negotiate(cx, site, frame.clone())) {
352 Ok(reply) => {
353 update_negotiated_codec_from_reply(runtime, session_id, &frame, &reply)?;
354 reply
355 }
356 Err(error) => runtime.with_cx(|cx| error_frame_from_error(cx, &frame, &error))?,
357 };
358 if runtime.is_stopping() {
359 let _ = runtime.close_session(session_id);
360 return Ok(());
361 }
362 write_frame_to(&mut self.stream, &reply)?;
363 runtime.note_message_sent();
364 if matches!(frame.kind, FrameKind::Request | FrameKind::Notify) {
365 inflight = inflight.saturating_sub(1);
366 }
367 }
368 }
369
370 fn recv_frame_for_serve(&mut self) -> Result<Option<Option<ServerFrame>>> {
371 self.stream
372 .set_read_timeout(Some(Duration::from_millis(SERVER_CONNECTION_IO_TIMEOUT_MS)))
373 .map_err(io_to_host)?;
374 match read_frame_from(&mut self.stream) {
375 Ok(frame) => Ok(Some(frame)),
376 Err(error) if is_timeout(&error) => Ok(None),
377 Err(error) => Err(error),
378 }
379 }
380}
381
382#[cfg(unix)]
383impl ConnectionTransport for UnixConnectionTransport {
384 fn send_frame(&mut self, _cx: &mut Cx, frame: ServerFrame) -> Result<()> {
385 write_frame_to(&mut self.stream, &frame)
386 }
387
388 fn recv_frame(
389 &mut self,
390 _cx: &mut Cx,
391 timeout: Option<Duration>,
392 ) -> Result<Option<ServerFrame>> {
393 self.stream.set_read_timeout(timeout).map_err(io_to_host)?;
394 match read_frame_from(&mut self.stream) {
395 Ok(frame) => Ok(frame),
396 Err(error) if is_timeout(&error) => Ok(None),
397 Err(error) => Err(error),
398 }
399 }
400
401 fn close(&mut self, _cx: &mut Cx) -> Result<()> {
402 Ok(())
403 }
404
405 fn as_any(&self) -> &dyn std::any::Any {
406 self
407 }
408
409 fn serve_connection(
410 &mut self,
411 runtime: &Arc<ServerRuntime>,
412 site: &Arc<dyn EvalSite>,
413 ) -> Result<()> {
414 self.serve(runtime, site)
415 }
416}
417
418#[cfg(unix)]
419fn remove_stale_unix_socket(path: &Path) -> Result<()> {
420 match fs::symlink_metadata(path) {
421 Ok(metadata) if metadata.file_type().is_socket() => {
422 fs::remove_file(path).map_err(io_to_host)?;
423 Ok(())
424 }
425 Ok(_) => Ok(()),
426 Err(error) if error.kind() == ErrorKind::NotFound => Ok(()),
427 Err(error) => Err(io_to_host(error)),
428 }
429}
430
431#[cfg(unix)]
432fn remove_bound_unix_socket(path: &Path) -> Result<()> {
433 match fs::symlink_metadata(path) {
434 Ok(metadata) if metadata.file_type().is_socket() => {
435 fs::remove_file(path).map_err(io_to_host)?;
436 Ok(())
437 }
438 Ok(_) => Ok(()),
439 Err(error) if error.kind() == ErrorKind::NotFound => Ok(()),
440 Err(error) => Err(io_to_host(error)),
441 }
442}