zero_packet/transport/
tcp.rs

1use crate::network::checksum::internet_checksum;
2use core::fmt;
3
/// Smallest possible TCP header, in bytes.
///
/// Buffers shorter than this are rejected by `TcpWriter::new` and `TcpReader::new`.
/// The value corresponds to a data offset of five 4-byte words (a header without options).
pub const TCP_MIN_HEADER_LENGTH: usize = 5 * 4;
6
7/// Writes TCP header fields.
8pub struct TcpWriter<'a> {
9    pub bytes: &'a mut [u8],
10}
11
12impl<'a> TcpWriter<'a> {
13    /// Creates a new `TcpWriter` from the given slice.
14    #[inline]
15    pub fn new(bytes: &'a mut [u8]) -> Result<Self, &'static str> {
16        if bytes.len() < TCP_MIN_HEADER_LENGTH {
17            return Err("Slice is too short to contain a TCP header.");
18        }
19
20        Ok(Self { bytes })
21    }
22
23    /// Returns the header length in bytes.
24    #[inline]
25    pub fn header_len(&self) -> usize {
26        (self.bytes[12] >> 4) as usize * 4
27    }
28
29    /// Returns the total length of the packet in bytes.
30    #[inline]
31    pub fn packet_len(&self) -> usize {
32        self.bytes.len()
33    }
34
35    /// Sets the source port field.
36    #[inline]
37    pub fn set_src_port(&mut self, src_port: u16) {
38        self.bytes[0] = (src_port >> 8) as u8;
39        self.bytes[1] = (src_port & 0xFF) as u8;
40    }
41
42    /// Sets the destination port field.
43    #[inline]
44    pub fn set_dest_port(&mut self, dest_port: u16) {
45        self.bytes[2] = (dest_port >> 8) as u8;
46        self.bytes[3] = (dest_port & 0xFF) as u8;
47    }
48
49    /// Sets the sequence number field.
50    #[inline]
51    pub fn set_sequence_number(&mut self, sequence_number: u32) {
52        self.bytes[4] = (sequence_number >> 24) as u8;
53        self.bytes[5] = (sequence_number >> 16) as u8;
54        self.bytes[6] = (sequence_number >> 8) as u8;
55        self.bytes[7] = (sequence_number & 0xFF) as u8;
56    }
57
58    /// Sets the acknowledgment number field.
59    #[inline]
60    pub fn set_ack_number(&mut self, acknowledgment_number: u32) {
61        self.bytes[8] = (acknowledgment_number >> 24) as u8;
62        self.bytes[9] = (acknowledgment_number >> 16) as u8;
63        self.bytes[10] = (acknowledgment_number >> 8) as u8;
64        self.bytes[11] = (acknowledgment_number & 0xFF) as u8;
65    }
66
67    /// Sets the data offset field.
68    #[inline]
69    pub fn set_data_offset(&mut self, data_offset: u8) {
70        self.bytes[12] = (data_offset << 4) | (self.bytes[12] & 0x0f);
71    }
72
73    /// Sets the reserved field.
74    #[inline]
75    pub fn set_reserved(&mut self, reserved: u8) {
76        self.bytes[12] = (self.bytes[12] & 0xf0) | (reserved & 0x0f);
77    }
78
79    /// Sets the flags field.
80    #[inline]
81    pub fn set_flags(&mut self, flags: u8) {
82        self.bytes[13] = flags;
83    }
84
85    /// Sets the window size field.
86    #[inline]
87    pub fn set_window_size(&mut self, window_size: u16) {
88        self.bytes[14] = (window_size >> 8) as u8;
89        self.bytes[15] = (window_size & 0xFF) as u8;
90    }
91
92    /// Sets the urgent pointer field.
93    #[inline]
94    pub fn set_urgent_pointer(&mut self, urgent_pointer: u16) {
95        self.bytes[18] = (urgent_pointer >> 8) as u8;
96        self.bytes[19] = (urgent_pointer & 0xFF) as u8;
97    }
98
99    /// Sets the payload field.
100    ///
101    /// The `PacketBuilder` sets the payload before the checksum.
102    ///
103    /// That is, because the checksum is invalidated if a payload is set after it.
104    #[inline]
105    pub fn set_payload(&mut self, payload: &[u8]) -> Result<(), &'static str> {
106        let start = self.header_len();
107        let payload_len = payload.len();
108
109        if self.packet_len() - start < payload_len {
110            return Err("Payload is too large to fit in the TCP packet.");
111        }
112
113        let end = start + payload_len;
114        self.bytes[start..end].copy_from_slice(payload);
115
116        Ok(())
117    }
118
119    /// Calculates the checksum field for error checking.
120    ///
121    /// Includes the TCP header, payload and IPv4 pseudo header.
122    #[inline]
123    pub fn set_checksum(&mut self, pseudo_sum: u32) {
124        self.bytes[16] = 0;
125        self.bytes[17] = 0;
126        let checksum = internet_checksum(self.bytes, pseudo_sum);
127        self.bytes[16] = (checksum >> 8) as u8;
128        self.bytes[17] = (checksum & 0xff) as u8;
129    }
130}
131
132/// Reads TCP header fields.
133#[derive(PartialEq)]
134pub struct TcpReader<'a> {
135    pub bytes: &'a [u8],
136}
137
138impl<'a> TcpReader<'a> {
139    /// Creates a new `TcpReader` from the given slice.
140    #[inline]
141    pub fn new(bytes: &'a [u8]) -> Result<Self, &'static str> {
142        if bytes.len() < TCP_MIN_HEADER_LENGTH {
143            return Err("Slice is too short to contain a TCP header.");
144        }
145
146        Ok(Self { bytes })
147    }
148
149    /// Returns the source port field.
150    #[inline]
151    pub fn src_port(&self) -> u16 {
152        ((self.bytes[0] as u16) << 8) | (self.bytes[1] as u16)
153    }
154
155    /// Returns the destination port field.
156    #[inline]
157    pub fn dest_port(&self) -> u16 {
158        ((self.bytes[2] as u16) << 8) | (self.bytes[3] as u16)
159    }
160
161    /// Returns the sequence number field.
162    #[inline]
163    pub fn sequence_number(&self) -> u32 {
164        ((self.bytes[4] as u32) << 24)
165            | ((self.bytes[5] as u32) << 16)
166            | ((self.bytes[6] as u32) << 8)
167            | (self.bytes[7] as u32)
168    }
169
170    /// Returns the acknowledgment number field.
171    #[inline]
172    pub fn ack_number(&self) -> u32 {
173        ((self.bytes[8] as u32) << 24)
174            | ((self.bytes[9] as u32) << 16)
175            | ((self.bytes[10] as u32) << 8)
176            | (self.bytes[11] as u32)
177    }
178
179    /// Returns the data offset field.
180    #[inline]
181    pub fn data_offset(&self) -> u8 {
182        self.bytes[12] >> 4
183    }
184
185    /// Returns the reserved field.
186    #[inline]
187    pub fn reserved(&self) -> u8 {
188        self.bytes[12] & 0x0F
189    }
190
191    /// Returns the flags field.
192    #[inline]
193    pub fn flags(&self) -> u8 {
194        self.bytes[13]
195    }
196
197    /// Returns the window size field.
198    #[inline]
199    pub fn window_size(&self) -> u16 {
200        ((self.bytes[14] as u16) << 8) | (self.bytes[15] as u16)
201    }
202
203    /// Returns the checksum field.
204    #[inline]
205    pub fn checksum(&self) -> u16 {
206        ((self.bytes[16] as u16) << 8) | (self.bytes[17] as u16)
207    }
208
209    /// Returns the urgent pointer field.
210    #[inline]
211    pub fn urgent_pointer(&self) -> u16 {
212        ((self.bytes[18] as u16) << 8) | (self.bytes[19] as u16)
213    }
214
215    /// Returns the header length in bytes by multiplying the data offset.
216    #[inline]
217    pub fn header_len(&self) -> usize {
218        self.data_offset() as usize * 4
219    }
220
221    /// Returns a reference to the header.
222    #[inline]
223    pub fn header(&self) -> Result<&'a [u8], &'static str> {
224        let end = self.header_len();
225
226        if end > self.bytes.len() {
227            return Err("Indicated TCP header length exceeds the allocated buffer.");
228        }
229
230        Ok(&self.bytes[..end])
231    }
232
233    /// Returns a reference to the payload.
234    #[inline]
235    pub fn payload(&self) -> Result<&'a [u8], &'static str> {
236        let start = self.header_len();
237
238        if start > self.bytes.len() {
239            return Err("Indicated TCP header length exceeds the allocated buffer.");
240        }
241
242        Ok(&self.bytes[start..])
243    }
244}
245
246impl fmt::Debug for TcpReader<'_> {
247    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
248        f.debug_struct("TcpSegment")
249            .field("src_port", &self.src_port())
250            .field("dest_port", &self.dest_port())
251            .field("sequence_number", &self.sequence_number())
252            .field("acknowledgment_number", &self.ack_number())
253            .field("data_offset", &self.data_offset())
254            .field("reserved", &self.reserved())
255            .field("flags", &self.flags())
256            .field("window_size", &self.window_size())
257            .field("checksum", &self.checksum())
258            .field("urgent_pointer", &self.urgent_pointer())
259            .finish()
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::network::checksum::pseudo_header;

    /// Writes every settable header field with `TcpWriter`, then reads each one
    /// back with `TcpReader` and checks that it round-trips unchanged.
    #[test]
    fn getters_and_setters() {
        // Arbitrary addresses used only to build the pseudo header sum.
        let src_ip = [192, 168, 1, 1];
        let dest_ip = [192, 168, 1, 2];

        // Arbitrary header field values.
        let src_port: u16 = 12345;
        let dest_port: u16 = 54321;
        let sequence_number: u32 = 12345;
        let acknowledgment_number: u32 = 54321;
        let reserved: u8 = 10;
        let data_offset: u8 = 5;
        let flags: u8 = 2;
        let window_size: u16 = 1024;
        let urgent_pointer: u16 = 5;

        // Zeroed 40-byte buffer standing in for the raw packet.
        let mut bytes = [0u8; 40];

        // Fill in the header through the writer.
        let mut writer = TcpWriter::new(&mut bytes).unwrap();
        writer.set_src_port(src_port);
        writer.set_dest_port(dest_port);
        writer.set_sequence_number(sequence_number);
        writer.set_ack_number(acknowledgment_number);
        writer.set_reserved(reserved);
        writer.set_data_offset(data_offset);
        writer.set_flags(flags);
        writer.set_window_size(window_size);
        writer.set_urgent_pointer(urgent_pointer);

        // The checksum goes last and is seeded with the pseudo header sum.
        let packet_len = writer.packet_len();
        let pseudo_sum = pseudo_header(&src_ip, &dest_ip, 6, packet_len);
        writer.set_checksum(pseudo_sum);

        // Read the same buffer back.
        let reader = TcpReader::new(&bytes).unwrap();

        // Every field must come back exactly as written.
        assert_eq!(reader.src_port(), src_port);
        assert_eq!(reader.dest_port(), dest_port);
        assert_eq!(reader.sequence_number(), sequence_number);
        assert_eq!(reader.ack_number(), acknowledgment_number);
        assert_eq!(reader.reserved(), reserved);
        assert_eq!(reader.data_offset(), data_offset);
        assert_eq!(reader.flags(), flags);
        assert_eq!(reader.window_size(), window_size);
        assert_eq!(reader.urgent_pointer(), urgent_pointer);
    }
}