1use crate::packet::{DataBytes, ErrorCode, Packet, PacketData, PacketErr, MAX_PACKET_SIZE};
2use log::{error, info};
3use mio::udp::UdpSocket;
4use mio::*;
5use mio_extras::timer::{Timeout, Timer};
6use rand;
7use rand::Rng;
8use std::collections::HashMap;
9use std::env;
10use std::fs;
11use std::fs::File;
12use std::io;
13use std::io::{Read, Write};
14use std::net;
15use std::net::SocketAddr;
16use std::path::PathBuf;
17use std::result;
18use std::str::FromStr;
19use std::time::Duration;
20use std::u16;
21
22const TIMEOUT: u64 = 3;
24const SERVER: Token = Token(0);
26const TIMER: Token = Token(1);
28
29#[derive(Debug)]
30pub enum TftpError {
31 PacketError(PacketErr),
32 IoError(io::Error),
33 TftpError(ErrorCode, SocketAddr),
37 NoOpenSocket,
40 CloseConnection,
45 NoneFromSocket,
49}
50
51impl From<io::Error> for TftpError {
52 fn from(err: io::Error) -> TftpError {
53 TftpError::IoError(err)
54 }
55}
56
57impl From<PacketErr> for TftpError {
58 fn from(err: PacketErr) -> TftpError {
59 TftpError::PacketError(err)
60 }
61}
62
63pub type Result<T> = result::Result<T, TftpError>;
64
/// State of one in-progress transfer, stored in `TftpServer::connections`
/// under the poll token of its socket.
struct ConnectionState {
    /// Socket dedicated to this transfer; registered with the poll under the
    /// connection's token.
    conn: UdpSocket,
    /// File being sent (RRQ) or received (WRQ).
    file: File,
    /// Pending retransmission timeout; re-armed whenever a packet is handled
    /// successfully or the timeout fires.
    timeout: Timeout,
    /// For reads: number of the last DATA block sent. For writes: number of the
    /// last DATA block acknowledged (0 right after the WRQ).
    block_num: u16,
    /// Last packet sent to the peer; resent when `timeout` fires.
    last_packet: Packet,
    /// Address of the peer this transfer exchanges packets with.
    addr: SocketAddr,
}
88
/// Builder for `TftpServer`.
///
/// Both settings are optional: without an address the server binds a random
/// localhost port, and without a serve directory it serves the current
/// working directory. `Default` yields the same empty builder as `new()`.
#[derive(Debug, Default)]
pub struct TftpServerBuilder {
    /// Address to bind the listening socket to, if any.
    addr: Option<SocketAddr>,
    /// Directory to serve files from, if any.
    serve_dir: Option<PathBuf>,
}
93
94impl TftpServerBuilder {
95 pub fn new() -> TftpServerBuilder {
96 TftpServerBuilder {
97 addr: None,
98 serve_dir: None,
99 }
100 }
101
102 pub fn addr_opt(mut self, addr: Option<SocketAddr>) -> TftpServerBuilder {
103 self.addr = addr;
104 self
105 }
106
107 pub fn addr(self, addr: SocketAddr) -> TftpServerBuilder {
108 self.addr_opt(Some(addr))
109 }
110
111 pub fn serve_dir_opt(mut self, serve_dir: Option<PathBuf>) -> TftpServerBuilder {
112 self.serve_dir = serve_dir;
113 self
114 }
115
116 pub fn serve_dir(self, serve_dir: PathBuf) -> TftpServerBuilder {
117 self.serve_dir_opt(Some(serve_dir))
118 }
119
120 pub fn build(self) -> Result<TftpServer> {
121 let poll = Poll::new()?;
122 let socket = match self.addr {
123 Some(addr) => UdpSocket::bind(&addr)?,
124 None => UdpSocket::from_socket(create_socket(Some(Duration::from_secs(TIMEOUT)))?)?,
125 };
126 let timer = Timer::default();
127 poll.register(&socket, SERVER, Ready::all(), PollOpt::edge())?;
128 poll.register(&timer, TIMER, Ready::readable(), PollOpt::edge())?;
129 let path = match self.serve_dir {
130 Some(path) => Some(path.canonicalize()?),
131 None => None,
132 };
133
134 Ok(TftpServer {
135 new_token: 2,
136 poll,
137 timer,
138 socket,
139 connections: HashMap::new(),
140 serve_dir: path,
141 })
142 }
143}
144
/// A TFTP server driven by a mio `Poll` event loop (see `TftpServer::run`).
pub struct TftpServer {
    /// Next token handed to a new connection; 0 and 1 are `SERVER` and `TIMER`.
    new_token: usize,
    /// Event loop on which the listening socket, the timer and every
    /// connection socket are registered.
    poll: Poll,
    /// Retransmission timer; each timeout carries its connection's token.
    timer: Timer<Token>,
    /// Listening socket that receives RRQ/WRQ requests.
    socket: UdpSocket,
    /// In-progress transfers keyed by their poll token.
    connections: HashMap<Token, ConnectionState>,
    /// Canonicalized directory to serve from; `None` means the process's
    /// current working directory.
    serve_dir: Option<PathBuf>,
}
160
161impl TftpServer {
162 fn generate_token(&mut self) -> Token {
164 let token = Token(self.new_token);
165 self.new_token += 1;
166 token
167 }
168
169 fn cancel_connection(&mut self, token: Token) -> Result<()> {
172 if let Some(conn) = self.connections.remove(&token) {
173 self.poll.deregister(&conn.conn)?;
174 self.timer
175 .cancel_timeout(&conn.timeout)
176 .expect("Error canceling timeout");
177 }
178 Ok(())
179 }
180
181 fn reset_timeout(&mut self, token: Token) -> Result<()> {
183 if let Some(ref mut conn) = self.connections.get_mut(&token) {
184 self.timer.cancel_timeout(&conn.timeout);
185 conn.timeout = self.timer.set_timeout(Duration::from_secs(TIMEOUT), token);
186 }
187 Ok(())
188 }
189
190 fn handle_server_packet(&mut self) -> Result<()> {
194 let mut buf = [0; MAX_PACKET_SIZE];
195 let (amt, src) = match self.socket.recv_from(&mut buf)? {
196 Some((amt, src)) => (amt, src),
197 None => return Err(TftpError::NoneFromSocket),
198 };
199 let packet = Packet::read(PacketData::new(buf, amt))?;
200
201 let (file, block_num, send_packet) = match packet {
203 Packet::RRQ { filename, mode } => {
204 handle_rrq_packet(filename, mode, &src, &self.serve_dir)?
205 }
206 Packet::WRQ { filename, mode } => {
207 handle_wrq_packet(filename, mode, &src, &self.serve_dir)?
208 }
209 _ => return Err(TftpError::TftpError(ErrorCode::IllegalTFTP, src)),
210 };
211
212 let socket = UdpSocket::from_socket(create_socket(Some(Duration::from_secs(TIMEOUT)))?)?;
214 let token = self.generate_token();
215 let timeout = self.timer.set_timeout(Duration::from_secs(TIMEOUT), token);
216 self.poll
217 .register(&socket, token, Ready::all(), PollOpt::edge())?;
218 info!("Created connection with token: {:?}", token);
219
220 socket.send_to(send_packet.clone().bytes()?.to_slice(), &src)?;
221 self.connections.insert(
222 token,
223 ConnectionState {
224 conn: socket,
225 file,
226 timeout,
227 block_num,
228 last_packet: send_packet,
229 addr: src,
230 },
231 );
232
233 Ok(())
234 }
235
236 fn handle_timer(&mut self) -> Result<()> {
240 let mut tokens = Vec::new();
241 while let Some(token) = self.timer.poll() {
242 tokens.push(token);
243 }
244
245 for token in tokens {
246 if let Some(ref mut conn) = self.connections.get_mut(&token) {
247 info!("Timeout: resending last packet for token: {:?}", token);
248 conn.conn
249 .send_to(conn.last_packet.clone().bytes()?.to_slice(), &conn.addr)?;
250 }
251 self.reset_timeout(token)?;
252 }
253
254 Ok(())
255 }
256
257 fn handle_connection_packet(&mut self, token: Token) -> Result<()> {
259 if let Some(ref mut conn) = self.connections.get_mut(&token) {
260 let mut buf = [0; MAX_PACKET_SIZE];
261 let amt = match conn.conn.recv_from(&mut buf)? {
262 Some((amt, _)) => amt,
263 None => return Err(TftpError::NoneFromSocket),
264 };
265 let packet = Packet::read(PacketData::new(buf, amt))?;
266
267 match packet {
268 Packet::ACK(block_num) => handle_ack_packet(block_num, conn)?,
269 Packet::DATA {
270 block_num,
271 data,
272 len,
273 } => handle_data_packet(block_num, data, len, conn)?,
274 Packet::ERROR { code, msg } => {
275 error!("Error message received with code {:?}: {:?}", code, msg);
276 return Err(TftpError::TftpError(code, conn.addr));
277 }
278 _ => {
279 error!("Received invalid packet from connection");
280 return Err(TftpError::TftpError(ErrorCode::IllegalTFTP, conn.addr));
281 }
282 }
283 }
284
285 Ok(())
286 }
287
288 fn handle_error(&mut self, token: Token, code: ErrorCode, addr: &SocketAddr) -> Result<()> {
290 if token == SERVER {
291 self.socket
292 .send_to(code.to_packet().bytes()?.to_slice(), addr)?;
293 } else if let Some(ref mut conn) = self.connections.get_mut(&token) {
294 conn.conn
295 .send_to(code.to_packet().bytes()?.to_slice(), addr)?;
296 }
297 Ok(())
298 }
299
300 pub fn handle_token(&mut self, token: Token) -> Result<()> {
304 match token {
305 SERVER => match self.handle_server_packet() {
306 Err(TftpError::NoneFromSocket) => {}
307 Err(TftpError::TftpError(code, addr)) => self.handle_error(token, code, &addr)?,
308 Err(e) => error!("Error: {:?}", e),
309 _ => {}
310 },
311 TIMER => self.handle_timer()?,
312 token if self.connections.get(&token).is_some() => {
313 match self.handle_connection_packet(token) {
314 Err(TftpError::CloseConnection) => {}
315 Err(TftpError::NoneFromSocket) => return Ok(()),
316 Err(TftpError::TftpError(code, addr)) => {
317 self.handle_error(token, code, &addr)?
318 }
319 Err(e) => error!("Error: {:?}", e),
320 _ => {
321 self.reset_timeout(token)?;
322 return Ok(());
323 }
324 }
325
326 info!("Closing connection with token {:?}", token);
327 self.cancel_connection(token)?;
328 return Ok(());
329 }
330 _ => unreachable!(),
331 }
332
333 Ok(())
334 }
335
336 pub fn run(&mut self) -> Result<()> {
338 let mut events = Events::with_capacity(1024);
339 loop {
340 self.poll.poll(&mut events, None)?;
341
342 for event in events.iter() {
343 self.handle_token(event.token())?;
344 }
345 }
346 }
347
348 pub fn local_addr(&self) -> Result<SocketAddr> {
350 Ok(self.socket.local_addr()?)
351 }
352}
353
354pub fn create_socket(timeout: Option<Duration>) -> Result<net::UdpSocket> {
358 let mut num_failures = 0;
359 let mut past_ports = HashMap::new();
360 loop {
361 let port = rand::thread_rng().gen_range(0, 65535);
362 if past_ports.get(&port).is_some() {
364 continue;
365 }
366
367 let addr = format!("127.0.0.1:{}", port);
368 let socket_addr = SocketAddr::from_str(addr.as_str()).expect("Error parsing address");
369 match net::UdpSocket::bind(&socket_addr) {
370 Ok(socket) => {
371 if let Some(timeout) = timeout {
372 socket.set_read_timeout(Some(timeout))?;
373 socket.set_write_timeout(Some(timeout))?;
374 }
375 return Ok(socket);
376 }
377 Err(_) => {
378 past_ports.insert(port, true);
379 num_failures += 1;
380 if num_failures > 100 {
381 return Err(TftpError::NoOpenSocket);
382 }
383 }
384 }
385 }
386}
387
/// Advances a TFTP block number by one.
///
/// Block numbers are 16 bits on the wire, so after 65535 they roll over to 0.
/// Every value, including 65535, is used in sequence.
pub fn incr_block_num(block_num: &mut u16) {
    *block_num = block_num.wrapping_add(1);
}
396
397fn handle_rrq_packet(
398 filename: String,
399 mode: String,
400 addr: &SocketAddr,
401 serve_dir: &Option<PathBuf>,
402) -> Result<(File, u16, Packet)> {
403 info!(
404 "Received RRQ packet with filename {} and mode {}",
405 filename, mode
406 );
407
408 let path = path_from_filename(filename, serve_dir, addr)?;
409
410 let mut file =
411 File::open(&path).map_err(|_| TftpError::TftpError(ErrorCode::FileNotFound, *addr))?;
412 let block_num = 1;
413
414 let mut buf = [0; 512];
415 let amount = file.read(&mut buf)?;
416
417 let last_packet = Packet::DATA {
419 block_num,
420 data: DataBytes(buf),
421 len: amount,
422 };
423
424 Ok((file, block_num, last_packet))
425}
426
427fn handle_wrq_packet(
428 filename: String,
429 mode: String,
430 addr: &SocketAddr,
431 serve_dir: &Option<PathBuf>,
432) -> Result<(File, u16, Packet)> {
433 info!(
434 "Received WRQ packet with filename {} and mode {}",
435 filename, mode
436 );
437
438 let path = path_from_filename(filename, serve_dir, addr)?;
439 if fs::metadata(&path).is_ok() {
440 return Err(TftpError::TftpError(ErrorCode::FileExists, *addr));
441 }
442 let file = File::create(&path)?;
443 let block_num = 0;
444
445 let last_packet = Packet::ACK(block_num);
447
448 Ok((file, block_num, last_packet))
449}
450
451fn handle_ack_packet(block_num: u16, conn: &mut ConnectionState) -> Result<()> {
452 info!("Received ACK with block number {}", block_num);
453 if block_num != conn.block_num {
454 return Ok(());
455 }
456
457 incr_block_num(&mut conn.block_num);
458 let mut buf = [0; 512];
459 let amount = conn.file.read(&mut buf)?;
460
461 conn.last_packet = Packet::DATA {
463 block_num: conn.block_num,
464 data: DataBytes(buf),
465 len: amount,
466 };
467 conn.conn
468 .send_to(conn.last_packet.clone().bytes()?.to_slice(), &conn.addr)?;
469
470 if amount < 512 {
471 Err(TftpError::CloseConnection)
472 } else {
473 Ok(())
474 }
475}
476
477fn handle_data_packet(
478 block_num: u16,
479 data: DataBytes,
480 len: usize,
481 conn: &mut ConnectionState,
482) -> Result<()> {
483 info!("Received data with block number {}", block_num);
484
485 incr_block_num(&mut conn.block_num);
486 if block_num != conn.block_num {
487 return Ok(());
488 }
489
490 conn.file.write_all(&data.0[0..len])?;
491
492 conn.last_packet = Packet::ACK(conn.block_num);
494 conn.conn
495 .send_to(conn.last_packet.clone().bytes()?.to_slice(), &conn.addr)?;
496
497 if len < 512 {
498 Err(TftpError::CloseConnection)
499 } else {
500 Ok(())
501 }
502}
503
504fn path_from_filename(
505 filename: String,
506 serve_dir: &Option<PathBuf>,
507 addr: &SocketAddr,
508) -> Result<PathBuf> {
509 if filename.contains("..") || filename.starts_with('/') || filename.starts_with("~/") {
510 return Err(TftpError::TftpError(ErrorCode::AccessViolation, *addr));
511 }
512 let mut path = match serve_dir {
513 Some(dir) => dir.clone(),
514 None => env::current_dir()?,
515 };
516 path.push(&filename);
517 Ok(path)
518}