tftp_packet/
lib.rs

1/*!
2# An implementation of the tftp packet
3
4This implements only conversion into and from bytes for a tftp packet and includes some generic enums and a custom Error type.
5
6## Example
7
8```rust
9use tftp_packet::Packet;
10use tftp_packet::Mode;
11
12let packet = Packet::RRQ{ filename: "test.txt".to_string(), mode: Mode::Octet };
13let bytes = packet.clone().to_bytes();
14assert_eq!(bytes, [0, 1, 116, 101, 115, 116, 46, 116, 120, 116, 0, 111, 99, 116, 101, 116, 0]);
15assert_eq!(Packet::from_bytes(&bytes).unwrap(), packet);
16```
17*/
18mod parsing;
19
20use std::convert::TryFrom;
21use std::{error::Error, fmt::Display};
22
23use parsing::parse_block_number;
24use parsing::parse_filename;
25use parsing::parse_mode;
26use parsing::take_u16;
27use parsing::{parse_error_code, parse_error_message};
28
/// Errors reported while converting bytes into a tftp packet
#[derive(Debug, PartialEq)]
pub enum PacketError {
    /// A malformed packet; the field holds a description of the problem
    InvalidPacket(String),
    /// A packet of the wrong size; the field holds the expected length
    InvalidPacketLength(u16),
    /// An unreadable or unknown opcode; the field holds a description of the problem
    InvalidOpcode(String),
}

impl Display for PacketError {
    /// Writes the variant name followed by its detail.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            PacketError::InvalidPacket(ref reason) => write!(f, "InvalidPacket: {}", reason),
            PacketError::InvalidPacketLength(expected) => {
                write!(f, "InvalidPacketLength: Expected {} bytes", expected)
            }
            PacketError::InvalidOpcode(ref reason) => write!(f, "InvalidOpcode: {}", reason),
        }
    }
}

impl Error for PacketError {}
53
/// All tftp opcodes defined in rfc1350
///
/// The wire value of an opcode (1 to 5) is mapped to a variant with `TryFrom<u16>`.
/// Casting a variant with `as u16` yields its zero-based declaration index instead,
/// which is one less than the wire value.
///
/// ```
/// # use tftp_packet::Opcode;
/// assert_eq!(Opcode::RRQ as u16, 0);
/// assert_eq!(Opcode::try_from(1), Ok(Opcode::RRQ));
/// assert_eq!(Opcode::try_from(6), Err("Invalid opcode"));
/// ```
#[derive(Debug, PartialEq, Clone)]
pub enum Opcode {
    RRQ,
    WRQ,
    DATA,
    ACK,
    ERROR,
}

impl TryFrom<u16> for Opcode {
    type Error = &'static str;

    /// Maps a wire opcode value to its variant.
    ///
    /// # Errors
    ///
    /// Returns `"Invalid opcode"` for any value outside `1..=5`.
    fn try_from(value: u16) -> Result<Self, Self::Error> {
        match value {
            1 => Ok(Opcode::RRQ),
            2 => Ok(Opcode::WRQ),
            3 => Ok(Opcode::DATA),
            4 => Ok(Opcode::ACK),
            5 => Ok(Opcode::ERROR),
            // The error type is a static string, so the rejected value cannot be
            // interpolated; the message therefore carries no format placeholder.
            _ => Err("Invalid opcode"),
        }
    }
}

impl TryFrom<&[u8; 2]> for Opcode {
    type Error = &'static str;

    /// Reads the two bytes as a big-endian `u16` and maps that value to its variant.
    ///
    /// # Errors
    ///
    /// Returns `"Invalid opcode"` when the decoded value is outside `1..=5`.
    fn try_from(value: &[u8; 2]) -> Result<Self, Self::Error> {
        Self::try_from(u16::from_be_bytes(*value))
    }
}
93
/// All modes defined in rfc1350
///
/// Parsing ignores ASCII case, because rfc1350 allows the mode string in any
/// combination of upper and lower case. `as_str` always returns the lowercase name.
///
/// ```
/// # use tftp_packet::Mode;
/// let mode: &str = Mode::Netascii.as_str();
/// assert_eq!(mode, "netascii");
/// assert_eq!(Mode::try_from("netascii"), Ok(Mode::Netascii));
/// assert_eq!(Mode::try_from("OCTET"), Ok(Mode::Octet));
/// ```
#[derive(Debug, PartialEq, Clone)]
pub enum Mode {
    Netascii,
    Octet,
    Mail,
}

impl Mode {
    /// Returns the lowercase name of the mode.
    pub fn as_str(&self) -> &'static str {
        match self {
            Mode::Netascii => "netascii",
            Mode::Octet => "octet",
            Mode::Mail => "mail",
        }
    }
}

impl TryFrom<&str> for Mode {
    type Error = &'static str;

    /// Parses a mode name, ignoring ASCII case.
    ///
    /// # Errors
    ///
    /// Returns `"Invalid mode"` when `value` is none of `netascii`, `octet` or `mail`.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        if value.eq_ignore_ascii_case("netascii") {
            Ok(Mode::Netascii)
        } else if value.eq_ignore_ascii_case("octet") {
            Ok(Mode::Octet)
        } else if value.eq_ignore_ascii_case("mail") {
            Ok(Mode::Mail)
        } else {
            Err("Invalid mode")
        }
    }
}

// `From` is implemented rather than `Into`: the standard blanket impl then
// provides `Into<&'static str> for &Mode`, so existing `.into()` calls keep working.
impl From<&Mode> for &'static str {
    fn from(mode: &Mode) -> Self {
        mode.as_str()
    }
}
137
/// The error codes rfc1350 defines for an ERROR packet
///
/// ```
/// # use tftp_packet::ErrorCode;
/// assert_eq!(ErrorCode::NotDefined as u16, 0u16);
/// assert_eq!(ErrorCode::try_from(0), Ok(ErrorCode::NotDefined));
/// ```
#[derive(Debug, PartialEq, Clone)]
pub enum ErrorCode {
    NotDefined,
    FileNotFound,
    AccessViolation,
    DiskFull,
    IllegalOperation,
    UnknownTransferId,
    FileAlreadyExists,
    NoSuchUser,
}

impl TryFrom<u16> for ErrorCode {
    type Error = &'static str;

    /// Looks up the variant for a numeric error code.
    ///
    /// Codes `0..=7` map to the variants in declaration order; any other value
    /// yields `Err("Invalid error code")`.
    fn try_from(value: u16) -> Result<Self, Self::Error> {
        // Indexed by code: position `n` holds the variant for code `n`.
        const BY_CODE: [ErrorCode; 8] = [
            ErrorCode::NotDefined,
            ErrorCode::FileNotFound,
            ErrorCode::AccessViolation,
            ErrorCode::DiskFull,
            ErrorCode::IllegalOperation,
            ErrorCode::UnknownTransferId,
            ErrorCode::FileAlreadyExists,
            ErrorCode::NoSuchUser,
        ];

        let found = BY_CODE.get(usize::from(value)).cloned();
        found.ok_or("Invalid error code")
    }
}
174
175/// The tftp packet
176#[derive(Debug, PartialEq, Clone)]
177pub enum Packet {
178    RRQ {
179        filename: String,
180        mode: Mode,
181    },
182    WRQ {
183        filename: String,
184        mode: Mode,
185    },
186    DATA {
187        block_number: u16,
188        data: Vec<u8>,
189    },
190    ACK {
191        block_number: u16,
192    },
193    ERROR {
194        error_code: ErrorCode,
195        error_msg: String,
196    },
197}
198
199impl Packet {
200    /// Parse a packet from a byte array
201    ///
202    /// ```
203    /// # use tftp_packet::Packet;
204    /// # use tftp_packet::Mode;
205    /// let packet = &[0u8, 1, 67, 68, 69, 0, 0x6f, 0x63, 0x74, 0x65, 0x74, 0];
206    /// let packet = Packet::from_bytes(packet).unwrap();
207    /// assert_eq!(packet, Packet::RRQ {
208    ///    filename: "CDE".to_string(),
209    ///    mode: Mode::Octet,
210    /// });
211    /// ```
212    pub fn from_bytes(bytes: &[u8]) -> Result<Self, PacketError> {
213        let (bytes, opcode_bytes) = take_u16(bytes).map_err(|_| {
214            PacketError::InvalidOpcode("Error while parsing opcode. Opcode not a u15.".to_string())
215        })?;
216
217        let opcode = Opcode::try_from(opcode_bytes).map_err(|e| {
218            PacketError::InvalidOpcode(format!("Error while parsing opcode: {}", e))
219        })?;
220
221        match opcode {
222            Opcode::RRQ => {
223                let (filename, bytes) = parse_filename(bytes)?;
224                let (mode, _bytes) = parse_mode(bytes)?;
225
226                Ok(Packet::RRQ { filename, mode })
227            }
228            Opcode::WRQ => {
229                let (filename, bytes) = parse_filename(bytes)?;
230                let (mode, _bytes) = parse_mode(bytes)?;
231
232                Ok(Packet::WRQ { filename, mode })
233            }
234            Opcode::DATA => {
235                let (block_number, bytes) = parse_block_number(bytes)?;
236
237                let data = bytes.to_vec();
238
239                if data.len() > 512 {
240                    Err(PacketError::InvalidPacketLength(512))?
241                }
242
243                Ok(Packet::DATA { block_number, data })
244            }
245            Opcode::ACK => {
246                let (block_number, bytes) = parse_block_number(bytes)?;
247
248                if bytes.is_empty() {
249                    Ok(Packet::ACK { block_number })
250                } else {
251                    Err(PacketError::InvalidPacketLength(4))?
252                }
253            }
254            Opcode::ERROR => {
255                let (error_code, bytes) = parse_error_code(bytes)?;
256                let (error_msg, _bytes) = parse_error_message(bytes)?;
257
258                Ok(Packet::ERROR {
259                    error_code,
260                    error_msg,
261                })
262            }
263        }
264    }
265
266    /// Serialize a packet into a byte array
267    ///
268    /// ```
269    /// # use tftp_packet::Packet;
270    /// # use tftp_packet::Mode;
271    /// let packet = Packet::RRQ {
272    ///   filename: "CDE".to_string(),
273    ///   mode: Mode::Octet,
274    /// };
275    /// let packet = packet.to_bytes();
276    /// assert_eq!(packet, &[0u8, 1, 67, 68, 69, 0, 0x6f, 0x63, 0x74, 0x65, 0x74, 0]);
277    /// ```
278    pub fn to_bytes(self) -> Vec<u8> {
279        match self {
280            Packet::RRQ { filename, mode } => {
281                let mut bytes = vec![0u8, 1];
282                bytes.extend(filename.as_bytes());
283                bytes.push(0);
284                bytes.extend(Into::<&str>::into(&mode).as_bytes());
285                bytes.push(0);
286                bytes
287            }
288            Packet::WRQ { filename, mode } => {
289                let mut bytes = vec![0u8, 2];
290                bytes.extend(filename.as_bytes());
291                bytes.push(0);
292                bytes.extend(Into::<&str>::into(&mode).as_bytes());
293                bytes.push(0);
294                bytes
295            }
296            Packet::DATA { block_number, data } => {
297                let mut bytes = vec![0u8, 3];
298                bytes.extend(block_number.to_be_bytes().to_vec());
299                bytes.extend(data);
300                bytes
301            }
302            Packet::ACK { block_number } => {
303                let mut bytes = vec![0u8, 4];
304                bytes.extend(block_number.to_be_bytes().to_vec());
305                bytes
306            }
307            Packet::ERROR {
308                error_code,
309                error_msg,
310            } => {
311                let mut bytes = vec![0u8, 5];
312                bytes.extend((error_code as u16).to_be_bytes().to_vec());
313                bytes.extend(error_msg.as_bytes());
314                bytes.push(0);
315                bytes
316            }
317        }
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Body shared by the request fixtures: file name "CDE", mode "octet", each zero-terminated.
    const CDE_OCTET: &[u8] = b"CDE\x00octet\x00";

    /// Builds the wire form of a request-shaped packet with the given opcode byte.
    fn request_bytes(opcode: u8) -> Vec<u8> {
        let mut raw = vec![0u8, opcode];
        raw.extend_from_slice(CDE_OCTET);
        raw
    }

    #[test]
    fn test_opcode() {
        let expected = [
            Opcode::RRQ,
            Opcode::WRQ,
            Opcode::DATA,
            Opcode::ACK,
            Opcode::ERROR,
        ];
        // Wire opcodes start at 1 and follow declaration order.
        for (code, opcode) in (1u16..).zip(expected.iter()) {
            assert_eq!(Ok(opcode.clone()), Opcode::try_from(code));
        }
    }

    #[test]
    fn test_rrq_parser() {
        let parsed = Packet::from_bytes(&request_bytes(1));
        let expected = Packet::RRQ {
            filename: String::from("CDE"),
            mode: Mode::Octet,
        };
        assert_eq!(parsed, Ok(expected));
    }

    #[test]
    fn test_wrq_parser() {
        let parsed = Packet::from_bytes(&request_bytes(2));
        let expected = Packet::WRQ {
            filename: String::from("CDE"),
            mode: Mode::Octet,
        };
        assert_eq!(parsed, Ok(expected));
    }

    #[test]
    fn test_data_parser() {
        let raw = [0u8, 3, 0, 42, b'C', b'D', b'E'];
        let expected = Packet::DATA {
            block_number: 42,
            data: b"CDE".to_vec(),
        };
        assert_eq!(Packet::from_bytes(&raw), Ok(expected));
    }

    #[test]
    fn test_ack_parser() {
        let raw = [0u8, 4, 0, 42];
        assert_eq!(
            Packet::from_bytes(&raw),
            Ok(Packet::ACK { block_number: 42 })
        );
    }

    #[test]
    fn test_error_parser() {
        let raw = [0u8, 5, 0, 2, b'C', b'D', b'E', 0];
        let expected = Packet::ERROR {
            error_code: ErrorCode::AccessViolation,
            error_msg: String::from("CDE"),
        };
        assert_eq!(Packet::from_bytes(&raw), Ok(expected));
    }

    #[test]
    fn test_invalid_opcode() {
        // Opcode 6 is not assigned.
        let outcome = Packet::from_bytes(&request_bytes(6));
        assert!(matches!(outcome, Err(PacketError::InvalidOpcode(_))));
    }

    #[test]
    fn test_invalid_data_length() {
        // DATA header followed by 513 payload bytes, one more than allowed.
        let raw: Vec<u8> = [0u8, 3, 0, 42]
            .iter()
            .copied()
            .chain(std::iter::repeat(69u8).take(513))
            .collect();
        let outcome = Packet::from_bytes(&raw);
        assert!(matches!(outcome, Err(PacketError::InvalidPacketLength(_))));
    }

    #[test]
    fn test_invalid_mode() {
        // Read request whose mode string is just "C".
        let raw = b"\x00\x01CDE\x00C\x00";
        let outcome = Packet::from_bytes(raw);
        assert!(matches!(outcome, Err(PacketError::InvalidPacket(_))));
    }

    #[test]
    fn test_rrq_to_bytes() {
        let packet = Packet::RRQ {
            filename: String::from("CDE"),
            mode: Mode::Octet,
        };
        assert_eq!(packet.to_bytes(), request_bytes(1));
    }

    #[test]
    fn test_wrq_to_bytes() {
        let packet = Packet::WRQ {
            filename: String::from("CDE"),
            mode: Mode::Octet,
        };
        assert_eq!(packet.to_bytes(), request_bytes(2));
    }

    #[test]
    fn test_data_to_bytes() {
        let packet = Packet::DATA {
            block_number: 42,
            data: b"CDE".to_vec(),
        };
        assert_eq!(packet.to_bytes(), [0u8, 3, 0, 42, b'C', b'D', b'E']);
    }

    #[test]
    fn test_ack_to_bytes() {
        let encoded = Packet::ACK { block_number: 42 }.to_bytes();
        assert_eq!(encoded, [0u8, 4, 0, 42]);
    }

    #[test]
    fn test_error_to_bytes() {
        let packet = Packet::ERROR {
            error_code: ErrorCode::AccessViolation,
            error_msg: String::from("CDE"),
        };
        assert_eq!(packet.to_bytes(), [0u8, 5, 0, 2, b'C', b'D', b'E', 0]);
    }
}