simple_socket/
server.rs

1use std::io;
2use std::marker::PhantomData;
3use std::net::SocketAddr;
4use std::time::Duration;
5
6use crate::backlog::Backlog;
7use crate::client::{SocketClient, SocketStatus};
8
9use bincode::Result;
10use serde::{de::DeserializeOwned, Serialize};
11use socket2::{Domain, Socket, Type};
12
13pub struct SocketServer<Req, Res>
14where
15    Req: DeserializeOwned,
16    Res: Serialize,
17{
18    streams: Vec<SocketClient<Res, Req>>,
19    listener: Socket,
20
21    _request: PhantomData<Req>,
22    _response: PhantomData<Res>,
23}
24
25impl<Req, Res> SocketServer<Req, Res>
26where
27    Req: DeserializeOwned,
28    Res: Serialize,
29{
30    pub fn try_new(addr: SocketAddr, backlog: Backlog) -> io::Result<Self> {
31        let domain = match addr {
32            SocketAddr::V4(_) => Domain::ipv4(),
33            SocketAddr::V6(_) => Domain::ipv6(),
34        };
35
36        let socket = Socket::new(domain, Type::stream(), None)?;
37        socket.bind(&addr.into())?;
38        socket.listen(backlog.into())?;
39
40        socket.set_nonblocking(true)?;
41
42        Ok(Self {
43            streams: vec![],
44            listener: socket,
45
46            _request: PhantomData::default(),
47            _response: PhantomData::default(),
48        })
49    }
50}
51
52impl<Req, Res> SocketServer<Req, Res>
53where
54    Req: DeserializeOwned,
55    Res: Serialize,
56{
57    pub fn run<H, P>(mut self, mut handler: H, post: P) -> Result<()>
58    where
59        H: FnMut(Req) -> Res,
60        P: Fn(&mut Self) -> PostServing,
61    {
62        loop {
63            if let Some(server_client) = self.accept()? {
64                self.streams.push(server_client);
65            }
66            for idx in (0..self.streams.len()).rev() {
67                let client = &mut self.streams[idx];
68
69                if let SocketStatus::Closed = client.response(|req| handler(req))? {
70                    self.streams.remove(idx);
71                }
72            }
73            match post(&mut self) {
74                PostServing::Wait(time) => std::thread::sleep(time),
75                PostServing::Yield => std::thread::yield_now(),
76                PostServing::Continue => continue,
77                PostServing::Stop => break Ok(()),
78            }
79        }
80    }
81
82    pub fn has_connections(&self) -> bool {
83        !self.streams.is_empty()
84    }
85
86    pub fn num_connections(&self) -> usize {
87        self.streams.len()
88    }
89
90    fn accept(&mut self) -> io::Result<Option<SocketClient<Res, Req>>> {
91        match self.listener.accept() {
92            Ok((stream, _)) => Ok(Some(SocketClient::try_from_stream(stream)?)),
93            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(None),
94            Err(error) => Err(error),
95        }
96    }
97}
98
/// What the serving loop of `SocketServer::run` should do after it has
/// served all clients once.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PostServing {
    /// Sleep the serving thread for the given duration, then loop again.
    Wait(Duration),
    /// Yield the serving thread's time slice, then loop again.
    Yield,
    /// Start the next iteration immediately.
    Continue,
    /// Leave the serving loop; `run` returns `Ok(())`.
    Stop,
}