socket_selector/
lib.rs

1use mio::{
2    net::{TcpListener, TcpStream},
3    Events, Interest, Poll, Token,
4};
5use std::{
6    io::{Read, Result, Write},
7    net::SocketAddr,
8    time::Duration,
9};
10
11pub trait Socket: Sized {
12    type Server: ConnectionHandler<Self>;
13
14    fn get_stream(&mut self) -> &mut TcpStream;
15
16    fn get_token(&self) -> &Token;
17
18    fn get_addr(&self) -> &SocketAddr;
19}
20
21pub struct ConnectionPool<Player: Socket> {
22    pub indexed_connection: Vec<Option<Player>>,
23    pub index_queue: Vec<usize>,
24}
25
26impl<Player: Socket> ConnectionPool<Player> {
27    fn get_socket(&mut self, token_index: usize) -> &mut Player {
28        unsafe { self.indexed_connection.get_unchecked_mut(token_index) }
29            .as_mut()
30            .expect("socket is none")
31    }
32
33    fn remove_socket(&mut self, token_index: usize) {
34        self.index_queue.push(token_index);
35        self.indexed_connection[token_index] = None;
36    }
37}
38
39pub trait ConnectionHandler<Player: Socket>: Sized {
40    fn handle_connection_accept(
41        &mut self,
42        stream: TcpStream,
43        token: Token,
44        addr: SocketAddr,
45    ) -> Player;
46
47    fn handle_connection_read(&mut self, socket: &mut Player, buf: &[u8]) -> Result<()>;
48
49    fn handle_connection_closed(&mut self, socket: &mut Player);
50
51    fn handle_update(&mut self);
52}
53
54pub struct Selector<Player: Socket, Server: ConnectionHandler<Player>> {
55    listener: TcpListener,
56    poll: Poll,
57    connection_pool: ConnectionPool<Player>,
58    connection_handler: Box<Server>,
59}
60
61impl<Player: Socket, Server: ConnectionHandler<Player>> Selector<Player, Server> {
62    pub fn new<const CONNECTION_POOL_SIZE: usize>(
63        addr: SocketAddr,
64        connection_handler: Server,
65    ) -> Selector<Player, Server> {
66        Selector {
67            listener: TcpListener::bind(addr).expect("Cannot start server"),
68            poll: Poll::new().expect("cannot create poll"),
69            connection_pool: ConnectionPool {
70                indexed_connection: Vec::with_capacity(CONNECTION_POOL_SIZE),
71                index_queue: Vec::with_capacity(CONNECTION_POOL_SIZE),
72            },
73            connection_handler: Box::new(connection_handler),
74        }
75    }
76
77    pub fn start_selection_loop<const MAX_READ_BUFFER_SIZE: usize>(
78        mut self,
79        timeout: Option<Duration>,
80    ) {
81        let server_token = Token(usize::MAX);
82        let poll = &mut self.poll;
83        let listener = &mut self.listener;
84        let connection_handler = &mut self.connection_handler;
85        let connection_pool = &mut self.connection_pool;
86        poll.registry()
87            .register(listener, server_token, Interest::READABLE)
88            .expect("Cannot reigster server to poll");
89        let buf = &mut [0u8; MAX_READ_BUFFER_SIZE];
90        let events_capacity = 128;
91        let events = &mut Events::with_capacity(events_capacity);
92        loop {
93            if let Err(_) = poll.poll(events, timeout) {
94                continue;
95            }
96            connection_handler.handle_update();
97            for event in events.iter() {
98                let token = event.token();
99                if token == server_token {
100                    if let Ok((stream, addr)) = listener.accept() {
101                        if let Some(index) = connection_pool.index_queue.pop() {
102                            let token = Token(index);
103                            let mut connection =
104                                connection_handler.handle_connection_accept(stream, token, addr);
105                            poll.registry()
106                                .register(connection.get_stream(), Token(index), Interest::READABLE)
107                                .expect("poll register");
108                            connection_pool.indexed_connection[index] = Some(connection);
109                        } else {
110                            let index = connection_pool.indexed_connection.len();
111                            let token = Token(index);
112                            let mut connection =
113                                connection_handler.handle_connection_accept(stream, token, addr);
114                            poll.registry()
115                                .register(connection.get_stream(), Token(index), Interest::READABLE)
116                                .expect("poll register");
117                            connection_pool.indexed_connection.push(Some(connection));
118                        }
119                    }
120                } else {
121                    let token_index = token.0;
122                    if event.is_readable() {
123                        let player = connection_pool.get_socket(token_index);
124                        let stream = player.get_stream();
125                        let read_result = stream.read(buf);
126                        if read_result.is_err() {
127                            poll.registry()
128                                .deregister(player.get_stream())
129                                .expect("cannot deregister socket");
130                            connection_handler.handle_connection_closed(player);
131                            connection_pool.remove_socket(token_index);
132                            continue;
133                        }
134                        let read = read_result.unwrap();
135                        if read == 0 {
136                            poll.registry()
137                                .deregister(player.get_stream())
138                                .expect("cannot deregister socket");
139                            connection_handler.handle_connection_closed(player);
140                            connection_pool.remove_socket(token_index);
141                            continue;
142                        } else {
143                            let read_buf = &buf[0..read];
144                            if let Err(err) =
145                                connection_handler.handle_connection_read(player, read_buf)
146                            {
147                                player.get_stream().flush();
148                                println!("Read handle error: {}", err);
149                                poll.registry()
150                                    .deregister(player.get_stream())
151                                    .expect("cannot deregister socket");
152                                connection_handler.handle_connection_closed(player);
153                                connection_pool.remove_socket(token_index);
154                                continue;
155                            }
156                        }
157                    }
158                }
159            }
160        }
161    }
162}
163
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use mio::{net::TcpStream, Token};

    use crate::{ConnectionHandler, Selector, Socket};

    /// Manual smoke test: starts a selector with a do-nothing handler.
    ///
    /// The selection loop never returns, so this is ignored by default to
    /// keep `cargo test` from hanging.
    #[test]
    #[ignore = "start_selection_loop never returns; run manually with `cargo test -- --ignored`"]
    fn start_selector() {
        let server = MyServer {};
        let addr = "0.0.0.0:1234".parse().unwrap();
        let selector = Selector::new::<256>(addr, server);
        selector.start_selection_loop::<10000>(None)
    }

    /// Handler whose callbacks do nothing beyond building a `Player`.
    struct MyServer {}

    impl ConnectionHandler<Player> for MyServer {
        fn handle_connection_accept(
            &mut self,
            stream: TcpStream,
            token: Token,
            addr: SocketAddr,
        ) -> Player {
            Player {
                stream,
                token,
                addr,
            }
        }

        fn handle_connection_read(
            &mut self,
            _socket: &mut Player,
            _buf: &[u8],
        ) -> std::io::Result<()> {
            //read
            Ok(())
        }

        fn handle_connection_closed(&mut self, _socket: &mut Player) {
            //on closed
        }

        fn handle_update(&mut self) {
            //update
        }
    }

    /// Minimal connection type: stores exactly what it was given on accept.
    struct Player {
        stream: TcpStream,
        token: Token,
        addr: SocketAddr,
    }

    impl Socket for Player {
        type Server = MyServer;

        fn get_stream(&mut self) -> &mut TcpStream {
            &mut self.stream
        }

        fn get_token(&self) -> &Token {
            &self.token
        }

        fn get_addr(&self) -> &SocketAddr {
            &self.addr
        }
    }
}