tftp_server/
server.rs

1use crate::packet::{DataBytes, ErrorCode, Packet, PacketData, PacketErr, MAX_PACKET_SIZE};
2use log::{error, info};
3use mio::udp::UdpSocket;
4use mio::*;
5use mio_extras::timer::{Timeout, Timer};
6use rand;
7use rand::Rng;
8use std::collections::HashMap;
9use std::env;
10use std::fs;
11use std::fs::File;
12use std::io;
13use std::io::{Read, Write};
14use std::net;
15use std::net::SocketAddr;
16use std::path::PathBuf;
17use std::result;
18use std::str::FromStr;
19use std::time::Duration;
20use std::u16;
21
22/// Timeout time until packet is re-sent.
23const TIMEOUT: u64 = 3;
24/// The token used by the server UDP socket.
25const SERVER: Token = Token(0);
26/// The token used by the timer.
27const TIMER: Token = Token(1);
28
29#[derive(Debug)]
30pub enum TftpError {
31    PacketError(PacketErr),
32    IoError(io::Error),
33    /// Error defined within the TFTP spec with an usigned integer
34    /// error code. The server should reply with an error packet
35    /// to the given socket address when handling this error.
36    TftpError(ErrorCode, SocketAddr),
37    /// Error returned when the server cannot
38    /// find a random open UDP port within 100 tries.
39    NoOpenSocket,
40    /// Error when the connection is to be closed normally
41    /// like when a data packet is smaller than 516 bytes.
42    /// The server should just close the connection without
43    /// propagating up the error.
44    CloseConnection,
45    /// Error when the socket receives a None value instead
46    /// of the source address when receiving from a socket.
47    /// This error should be ignored by the server.
48    NoneFromSocket,
49}
50
51impl From<io::Error> for TftpError {
52    fn from(err: io::Error) -> TftpError {
53        TftpError::IoError(err)
54    }
55}
56
57impl From<PacketErr> for TftpError {
58    fn from(err: PacketErr) -> TftpError {
59        TftpError::PacketError(err)
60    }
61}
62
63pub type Result<T> = result::Result<T, TftpError>;
64
65/// The state contained within a connection.
66/// A connection is started when a server socket receives
67/// a RRQ or a WRQ packet and ends when the connection socket
68/// receives a DATA packet less than 516 bytes or if the connection
69/// socket receives an invalid packet.
70struct ConnectionState {
71    /// The UDP socket for the connection that receives ACK, DATA, or ERROR packets.
72    conn: UdpSocket,
73    /// The open file either being written to or read from during the transfer.
74    /// If the connection was started with a RRQ, the file would be read from, if it
75    /// was started with a WRQ, the file would be written to.
76    file: File,
77    /// The timeout for the last packet. Every time a new packet is received, the
78    /// timeout is reset.
79    timeout: Timeout,
80    /// The current block number of the transfer. If the block numbers of the received packet
81    /// and the current block number do not match, the connection is closed.
82    block_num: u16,
83    /// The last packet sent. This is used when a timeout happens to resend the last packet.
84    last_packet: Packet,
85    /// The address of the client socket to reply to.
86    addr: SocketAddr,
87}
88
/// Builder for `TftpServer`.
///
/// Both settings are optional. Without an address, the server socket is bound
/// to a random port on `127.0.0.1`. Without a serve directory, requested file
/// names are resolved against the current working directory.
#[derive(Debug, Default)]
pub struct TftpServerBuilder {
    /// Address for the main server socket, if one was requested.
    addr: Option<SocketAddr>,
    /// Directory to serve files from, if one was requested.
    serve_dir: Option<PathBuf>,
}
93
94impl TftpServerBuilder {
95    pub fn new() -> TftpServerBuilder {
96        TftpServerBuilder {
97            addr: None,
98            serve_dir: None,
99        }
100    }
101
102    pub fn addr_opt(mut self, addr: Option<SocketAddr>) -> TftpServerBuilder {
103        self.addr = addr;
104        self
105    }
106
107    pub fn addr(self, addr: SocketAddr) -> TftpServerBuilder {
108        self.addr_opt(Some(addr))
109    }
110
111    pub fn serve_dir_opt(mut self, serve_dir: Option<PathBuf>) -> TftpServerBuilder {
112        self.serve_dir = serve_dir;
113        self
114    }
115
116    pub fn serve_dir(self, serve_dir: PathBuf) -> TftpServerBuilder {
117        self.serve_dir_opt(Some(serve_dir))
118    }
119
120    pub fn build(self) -> Result<TftpServer> {
121        let poll = Poll::new()?;
122        let socket = match self.addr {
123            Some(addr) => UdpSocket::bind(&addr)?,
124            None => UdpSocket::from_socket(create_socket(Some(Duration::from_secs(TIMEOUT)))?)?,
125        };
126        let timer = Timer::default();
127        poll.register(&socket, SERVER, Ready::all(), PollOpt::edge())?;
128        poll.register(&timer, TIMER, Ready::readable(), PollOpt::edge())?;
129        let path = match self.serve_dir {
130            Some(path) => Some(path.canonicalize()?),
131            None => None,
132        };
133
134        Ok(TftpServer {
135            new_token: 2,
136            poll,
137            timer,
138            socket,
139            connections: HashMap::new(),
140            serve_dir: path,
141        })
142    }
143}
144
145pub struct TftpServer {
146    /// The ID of a new token used for generating different tokens.
147    new_token: usize,
148    /// The event loop for handling async events.
149    poll: Poll,
150    /// The main timer that can be used to set multiple timeout events.
151    timer: Timer<Token>,
152    /// The main server socket that receives RRQ and WRQ packets
153    /// and creates a new separate UDP connection.
154    socket: UdpSocket,
155    /// The separate UDP connections for handling multiple requests.
156    connections: HashMap<Token, ConnectionState>,
157    /// The directory to serve the files in.
158    serve_dir: Option<PathBuf>,
159}
160
161impl TftpServer {
162    /// Returns a new token created from incrementing a counter.
163    fn generate_token(&mut self) -> Token {
164        let token = Token(self.new_token);
165        self.new_token += 1;
166        token
167    }
168
169    /// Cancels a connection given the connection's token. It cancels the
170    /// connection's timeout and deregisters the connection's socket from the event loop.
171    fn cancel_connection(&mut self, token: Token) -> Result<()> {
172        if let Some(conn) = self.connections.remove(&token) {
173            self.poll.deregister(&conn.conn)?;
174            self.timer
175                .cancel_timeout(&conn.timeout)
176                .expect("Error canceling timeout");
177        }
178        Ok(())
179    }
180
181    /// Resets a connection's timeout given the connection's token.
182    fn reset_timeout(&mut self, token: Token) -> Result<()> {
183        if let Some(ref mut conn) = self.connections.get_mut(&token) {
184            self.timer.cancel_timeout(&conn.timeout);
185            conn.timeout = self.timer.set_timeout(Duration::from_secs(TIMEOUT), token);
186        }
187        Ok(())
188    }
189
190    /// Handles a packet sent to the main server connection.
191    /// It opens a new UDP connection in a random port and replies with either an ACK
192    /// or a DATA packet depending on the whether it received an RRQ or a WRQ packet.
193    fn handle_server_packet(&mut self) -> Result<()> {
194        let mut buf = [0; MAX_PACKET_SIZE];
195        let (amt, src) = match self.socket.recv_from(&mut buf)? {
196            Some((amt, src)) => (amt, src),
197            None => return Err(TftpError::NoneFromSocket),
198        };
199        let packet = Packet::read(PacketData::new(buf, amt))?;
200
201        // Handle the RRQ or WRQ packet.
202        let (file, block_num, send_packet) = match packet {
203            Packet::RRQ { filename, mode } => {
204                handle_rrq_packet(filename, mode, &src, &self.serve_dir)?
205            }
206            Packet::WRQ { filename, mode } => {
207                handle_wrq_packet(filename, mode, &src, &self.serve_dir)?
208            }
209            _ => return Err(TftpError::TftpError(ErrorCode::IllegalTFTP, src)),
210        };
211
212        // Create new connection.
213        let socket = UdpSocket::from_socket(create_socket(Some(Duration::from_secs(TIMEOUT)))?)?;
214        let token = self.generate_token();
215        let timeout = self.timer.set_timeout(Duration::from_secs(TIMEOUT), token);
216        self.poll
217            .register(&socket, token, Ready::all(), PollOpt::edge())?;
218        info!("Created connection with token: {:?}", token);
219
220        socket.send_to(send_packet.clone().bytes()?.to_slice(), &src)?;
221        self.connections.insert(
222            token,
223            ConnectionState {
224                conn: socket,
225                file,
226                timeout,
227                block_num,
228                last_packet: send_packet,
229                addr: src,
230            },
231        );
232
233        Ok(())
234    }
235
236    /// Handles the event when a timer times out.
237    /// It gets the connection from the token and resends
238    /// the last packet sent from the connection.
239    fn handle_timer(&mut self) -> Result<()> {
240        let mut tokens = Vec::new();
241        while let Some(token) = self.timer.poll() {
242            tokens.push(token);
243        }
244
245        for token in tokens {
246            if let Some(ref mut conn) = self.connections.get_mut(&token) {
247                info!("Timeout: resending last packet for token: {:?}", token);
248                conn.conn
249                    .send_to(conn.last_packet.clone().bytes()?.to_slice(), &conn.addr)?;
250            }
251            self.reset_timeout(token)?;
252        }
253
254        Ok(())
255    }
256
257    /// Handles a packet sent to an open child connection.
258    fn handle_connection_packet(&mut self, token: Token) -> Result<()> {
259        if let Some(ref mut conn) = self.connections.get_mut(&token) {
260            let mut buf = [0; MAX_PACKET_SIZE];
261            let amt = match conn.conn.recv_from(&mut buf)? {
262                Some((amt, _)) => amt,
263                None => return Err(TftpError::NoneFromSocket),
264            };
265            let packet = Packet::read(PacketData::new(buf, amt))?;
266
267            match packet {
268                Packet::ACK(block_num) => handle_ack_packet(block_num, conn)?,
269                Packet::DATA {
270                    block_num,
271                    data,
272                    len,
273                } => handle_data_packet(block_num, data, len, conn)?,
274                Packet::ERROR { code, msg } => {
275                    error!("Error message received with code {:?}: {:?}", code, msg);
276                    return Err(TftpError::TftpError(code, conn.addr));
277                }
278                _ => {
279                    error!("Received invalid packet from connection");
280                    return Err(TftpError::TftpError(ErrorCode::IllegalTFTP, conn.addr));
281                }
282            }
283        }
284
285        Ok(())
286    }
287
288    /// Handles sending error packets given the error code.
289    fn handle_error(&mut self, token: Token, code: ErrorCode, addr: &SocketAddr) -> Result<()> {
290        if token == SERVER {
291            self.socket
292                .send_to(code.to_packet().bytes()?.to_slice(), addr)?;
293        } else if let Some(ref mut conn) = self.connections.get_mut(&token) {
294            conn.conn
295                .send_to(code.to_packet().bytes()?.to_slice(), addr)?;
296        }
297        Ok(())
298    }
299
300    /// Called for every event sent from the event loop. The event
301    /// is a token that can either be from the server, from an open connection,
302    /// or from a timeout timer for a connection.
303    pub fn handle_token(&mut self, token: Token) -> Result<()> {
304        match token {
305            SERVER => match self.handle_server_packet() {
306                Err(TftpError::NoneFromSocket) => {}
307                Err(TftpError::TftpError(code, addr)) => self.handle_error(token, code, &addr)?,
308                Err(e) => error!("Error: {:?}", e),
309                _ => {}
310            },
311            TIMER => self.handle_timer()?,
312            token if self.connections.get(&token).is_some() => {
313                match self.handle_connection_packet(token) {
314                    Err(TftpError::CloseConnection) => {}
315                    Err(TftpError::NoneFromSocket) => return Ok(()),
316                    Err(TftpError::TftpError(code, addr)) => {
317                        self.handle_error(token, code, &addr)?
318                    }
319                    Err(e) => error!("Error: {:?}", e),
320                    _ => {
321                        self.reset_timeout(token)?;
322                        return Ok(());
323                    }
324                }
325
326                info!("Closing connection with token {:?}", token);
327                self.cancel_connection(token)?;
328                return Ok(());
329            }
330            _ => unreachable!(),
331        }
332
333        Ok(())
334    }
335
336    /// Runs the server's event loop.
337    pub fn run(&mut self) -> Result<()> {
338        let mut events = Events::with_capacity(1024);
339        loop {
340            self.poll.poll(&mut events, None)?;
341
342            for event in events.iter() {
343                self.handle_token(event.token())?;
344            }
345        }
346    }
347
348    /// Returns the socket address of the server socket.
349    pub fn local_addr(&self) -> Result<SocketAddr> {
350        Ok(self.socket.local_addr()?)
351    }
352}
353
354/// Creates a std::net::UdpSocket on a random open UDP port.
355/// The range of valid ports is from 0 to 65535 and if the function
356/// cannot find a open port within 100 different random ports it returns an error.
357pub fn create_socket(timeout: Option<Duration>) -> Result<net::UdpSocket> {
358    let mut num_failures = 0;
359    let mut past_ports = HashMap::new();
360    loop {
361        let port = rand::thread_rng().gen_range(0, 65535);
362        // Ignore ports that already failed.
363        if past_ports.get(&port).is_some() {
364            continue;
365        }
366
367        let addr = format!("127.0.0.1:{}", port);
368        let socket_addr = SocketAddr::from_str(addr.as_str()).expect("Error parsing address");
369        match net::UdpSocket::bind(&socket_addr) {
370            Ok(socket) => {
371                if let Some(timeout) = timeout {
372                    socket.set_read_timeout(Some(timeout))?;
373                    socket.set_write_timeout(Some(timeout))?;
374                }
375                return Ok(socket);
376            }
377            Err(_) => {
378                past_ports.insert(port, true);
379                num_failures += 1;
380                if num_failures > 100 {
381                    return Err(TftpError::NoOpenSocket);
382                }
383            }
384        }
385    }
386}
387
/// Increments a TFTP block number in place.
///
/// Block numbers are 16 bits wide, so after 65535 the counter wraps around
/// to 0 instead of overflowing.
pub fn incr_block_num(block_num: &mut u16) {
    *block_num = block_num.wrapping_add(1);
}
396
397fn handle_rrq_packet(
398    filename: String,
399    mode: String,
400    addr: &SocketAddr,
401    serve_dir: &Option<PathBuf>,
402) -> Result<(File, u16, Packet)> {
403    info!(
404        "Received RRQ packet with filename {} and mode {}",
405        filename, mode
406    );
407
408    let path = path_from_filename(filename, serve_dir, addr)?;
409
410    let mut file =
411        File::open(&path).map_err(|_| TftpError::TftpError(ErrorCode::FileNotFound, *addr))?;
412    let block_num = 1;
413
414    let mut buf = [0; 512];
415    let amount = file.read(&mut buf)?;
416
417    // Reply with first data packet with a block number of 1.
418    let last_packet = Packet::DATA {
419        block_num,
420        data: DataBytes(buf),
421        len: amount,
422    };
423
424    Ok((file, block_num, last_packet))
425}
426
427fn handle_wrq_packet(
428    filename: String,
429    mode: String,
430    addr: &SocketAddr,
431    serve_dir: &Option<PathBuf>,
432) -> Result<(File, u16, Packet)> {
433    info!(
434        "Received WRQ packet with filename {} and mode {}",
435        filename, mode
436    );
437
438    let path = path_from_filename(filename, serve_dir, addr)?;
439    if fs::metadata(&path).is_ok() {
440        return Err(TftpError::TftpError(ErrorCode::FileExists, *addr));
441    }
442    let file = File::create(&path)?;
443    let block_num = 0;
444
445    // Reply with ACK with a block number of 0.
446    let last_packet = Packet::ACK(block_num);
447
448    Ok((file, block_num, last_packet))
449}
450
451fn handle_ack_packet(block_num: u16, conn: &mut ConnectionState) -> Result<()> {
452    info!("Received ACK with block number {}", block_num);
453    if block_num != conn.block_num {
454        return Ok(());
455    }
456
457    incr_block_num(&mut conn.block_num);
458    let mut buf = [0; 512];
459    let amount = conn.file.read(&mut buf)?;
460
461    // Send next data packet.
462    conn.last_packet = Packet::DATA {
463        block_num: conn.block_num,
464        data: DataBytes(buf),
465        len: amount,
466    };
467    conn.conn
468        .send_to(conn.last_packet.clone().bytes()?.to_slice(), &conn.addr)?;
469
470    if amount < 512 {
471        Err(TftpError::CloseConnection)
472    } else {
473        Ok(())
474    }
475}
476
477fn handle_data_packet(
478    block_num: u16,
479    data: DataBytes,
480    len: usize,
481    conn: &mut ConnectionState,
482) -> Result<()> {
483    info!("Received data with block number {}", block_num);
484
485    incr_block_num(&mut conn.block_num);
486    if block_num != conn.block_num {
487        return Ok(());
488    }
489
490    conn.file.write_all(&data.0[0..len])?;
491
492    // Send ACK packet for data.
493    conn.last_packet = Packet::ACK(conn.block_num);
494    conn.conn
495        .send_to(conn.last_packet.clone().bytes()?.to_slice(), &conn.addr)?;
496
497    if len < 512 {
498        Err(TftpError::CloseConnection)
499    } else {
500        Ok(())
501    }
502}
503
504fn path_from_filename(
505    filename: String,
506    serve_dir: &Option<PathBuf>,
507    addr: &SocketAddr,
508) -> Result<PathBuf> {
509    if filename.contains("..") || filename.starts_with('/') || filename.starts_with("~/") {
510        return Err(TftpError::TftpError(ErrorCode::AccessViolation, *addr));
511    }
512    let mut path = match serve_dir {
513        Some(dir) => dir.clone(),
514        None => env::current_dir()?,
515    };
516    path.push(&filename);
517    Ok(path)
518}