use mio::{
    net::{TcpListener, TcpStream},
    Events, Interest, Poll, Token,
};
use std::{
    io::{ErrorKind, Read, Result, Write},
    net::SocketAddr,
    time::Duration,
};
10
11pub trait Socket: Sized {
12 type Server: ConnectionHandler<Self>;
13
14 fn get_stream(&mut self) -> &mut TcpStream;
15
16 fn get_token(&self) -> &Token;
17
18 fn get_addr(&self) -> &SocketAddr;
19}
20
21pub struct ConnectionPool<Player: Socket> {
22 pub indexed_connection: Vec<Option<Player>>,
23 pub index_queue: Vec<usize>,
24}
25
26impl<Player: Socket> ConnectionPool<Player> {
27 fn get_socket(&mut self, token_index: usize) -> &mut Player {
28 unsafe { self.indexed_connection.get_unchecked_mut(token_index) }
29 .as_mut()
30 .expect("socket is none")
31 }
32
33 fn remove_socket(&mut self, token_index: usize) {
34 self.index_queue.push(token_index);
35 self.indexed_connection[token_index] = None;
36 }
37}
38
39pub trait ConnectionHandler<Player: Socket>: Sized {
40 fn handle_connection_accept(
41 &mut self,
42 stream: TcpStream,
43 token: Token,
44 addr: SocketAddr,
45 ) -> Player;
46
47 fn handle_connection_read(&mut self, socket: &mut Player, buf: &[u8]) -> Result<()>;
48
49 fn handle_connection_closed(&mut self, socket: &mut Player);
50
51 fn handle_update(&mut self);
52}
53
54pub struct Selector<Player: Socket, Server: ConnectionHandler<Player>> {
55 listener: TcpListener,
56 poll: Poll,
57 connection_pool: ConnectionPool<Player>,
58 connection_handler: Box<Server>,
59}
60
61impl<Player: Socket, Server: ConnectionHandler<Player>> Selector<Player, Server> {
62 pub fn new<const CONNECTION_POOL_SIZE: usize>(
63 addr: SocketAddr,
64 connection_handler: Server,
65 ) -> Selector<Player, Server> {
66 Selector {
67 listener: TcpListener::bind(addr).expect("Cannot start server"),
68 poll: Poll::new().expect("cannot create poll"),
69 connection_pool: ConnectionPool {
70 indexed_connection: Vec::with_capacity(CONNECTION_POOL_SIZE),
71 index_queue: Vec::with_capacity(CONNECTION_POOL_SIZE),
72 },
73 connection_handler: Box::new(connection_handler),
74 }
75 }
76
77 pub fn start_selection_loop<const MAX_READ_BUFFER_SIZE: usize>(
78 mut self,
79 timeout: Option<Duration>,
80 ) {
81 let server_token = Token(usize::MAX);
82 let poll = &mut self.poll;
83 let listener = &mut self.listener;
84 let connection_handler = &mut self.connection_handler;
85 let connection_pool = &mut self.connection_pool;
86 poll.registry()
87 .register(listener, server_token, Interest::READABLE)
88 .expect("Cannot reigster server to poll");
89 let buf = &mut [0u8; MAX_READ_BUFFER_SIZE];
90 let events_capacity = 128;
91 let events = &mut Events::with_capacity(events_capacity);
92 loop {
93 if let Err(_) = poll.poll(events, timeout) {
94 continue;
95 }
96 connection_handler.handle_update();
97 for event in events.iter() {
98 let token = event.token();
99 if token == server_token {
100 if let Ok((stream, addr)) = listener.accept() {
101 if let Some(index) = connection_pool.index_queue.pop() {
102 let token = Token(index);
103 let mut connection =
104 connection_handler.handle_connection_accept(stream, token, addr);
105 poll.registry()
106 .register(connection.get_stream(), Token(index), Interest::READABLE)
107 .expect("poll register");
108 connection_pool.indexed_connection[index] = Some(connection);
109 } else {
110 let index = connection_pool.indexed_connection.len();
111 let token = Token(index);
112 let mut connection =
113 connection_handler.handle_connection_accept(stream, token, addr);
114 poll.registry()
115 .register(connection.get_stream(), Token(index), Interest::READABLE)
116 .expect("poll register");
117 connection_pool.indexed_connection.push(Some(connection));
118 }
119 }
120 } else {
121 let token_index = token.0;
122 if event.is_readable() {
123 let player = connection_pool.get_socket(token_index);
124 let stream = player.get_stream();
125 let read_result = stream.read(buf);
126 if read_result.is_err() {
127 poll.registry()
128 .deregister(player.get_stream())
129 .expect("cannot deregister socket");
130 connection_handler.handle_connection_closed(player);
131 connection_pool.remove_socket(token_index);
132 continue;
133 }
134 let read = read_result.unwrap();
135 if read == 0 {
136 poll.registry()
137 .deregister(player.get_stream())
138 .expect("cannot deregister socket");
139 connection_handler.handle_connection_closed(player);
140 connection_pool.remove_socket(token_index);
141 continue;
142 } else {
143 let read_buf = &buf[0..read];
144 if let Err(err) =
145 connection_handler.handle_connection_read(player, read_buf)
146 {
147 player.get_stream().flush();
148 println!("Read handle error: {}", err);
149 poll.registry()
150 .deregister(player.get_stream())
151 .expect("cannot deregister socket");
152 connection_handler.handle_connection_closed(player);
153 connection_pool.remove_socket(token_index);
154 continue;
155 }
156 }
157 }
158 }
159 }
160 }
161 }
162}
163
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use mio::{net::TcpStream, Token};

    use crate::{ConnectionHandler, Selector, Socket};

    /// Manual smoke test that serves on 0.0.0.0:1234 forever.
    ///
    /// `start_selection_loop` never returns, so the test is ignored by default to
    /// keep `cargo test` from hanging. Run it explicitly with
    /// `cargo test start_selector -- --ignored`.
    #[test]
    #[ignore]
    fn start_selector() {
        let server = MyServer {};
        let addr = "0.0.0.0:1234".parse().unwrap();
        let selector = Selector::new::<256>(addr, server);
        selector.start_selection_loop::<10000>(None)
    }

    /// Handler that accepts every connection and ignores all input.
    struct MyServer {}

    impl ConnectionHandler<Player> for MyServer {
        fn handle_connection_accept(
            &mut self,
            stream: TcpStream,
            token: Token,
            addr: SocketAddr,
        ) -> Player {
            Player {
                stream,
                token,
                addr,
            }
        }

        fn handle_connection_read(
            &mut self,
            _socket: &mut Player,
            _buf: &[u8],
        ) -> std::io::Result<()> {
            Ok(())
        }

        fn handle_connection_closed(&mut self, _socket: &mut Player) {}

        fn handle_update(&mut self) {}
    }

    /// Minimal connection type that stores what it was accepted with.
    struct Player {
        stream: TcpStream,
        token: Token,
        addr: SocketAddr,
    }

    impl Socket for Player {
        type Server = MyServer;

        fn get_stream(&mut self) -> &mut TcpStream {
            &mut self.stream
        }

        fn get_token(&self) -> &Token {
            &self.token
        }

        fn get_addr(&self) -> &SocketAddr {
            &self.addr
        }
    }
}