Skip to main content

stackforge_core/layer/udp/
mod.rs

1//! UDP (User Datagram Protocol) layer implementation.
2//!
3//! This module provides types and functions for working with UDP packets,
4//! including parsing, field access, and checksum calculation.
5
6pub mod builder;
7pub mod checksum;
8
9// Re-export builder
10pub use builder::UdpBuilder;
11
12// Re-export checksum functions
13pub use checksum::{
14    udp_checksum_ipv4, udp_checksum_ipv6, verify_udp_checksum_ipv4, verify_udp_checksum_ipv6,
15};
16
17use crate::layer::field::{FieldDesc, FieldError, FieldType, FieldValue};
18use crate::layer::{Layer, LayerIndex, LayerKind};
19
20/// UDP header length (8 bytes fixed).
21pub const UDP_HEADER_LEN: usize = 8;
22
23/// Field offsets within the UDP header.
24pub mod offsets {
25    pub const SRC_PORT: usize = 0;
26    pub const DST_PORT: usize = 2;
27    pub const LENGTH: usize = 4;
28    pub const CHECKSUM: usize = 6;
29}
30
31/// UDP field descriptors for dynamic access.
32pub static FIELDS: &[FieldDesc] = &[
33    FieldDesc::new("sport", offsets::SRC_PORT, 2, FieldType::U16),
34    FieldDesc::new("dport", offsets::DST_PORT, 2, FieldType::U16),
35    FieldDesc::new("len", offsets::LENGTH, 2, FieldType::U16),
36    FieldDesc::new("chksum", offsets::CHECKSUM, 2, FieldType::U16),
37];
38
39/// UDP layer representation.
40///
41/// UDP is a simple, connectionless transport protocol defined in RFC 768.
42/// The header is 8 bytes:
43///
44/// ```text
45///  0                   1                   2                   3
46///  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
47/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
48/// |          Source Port          |       Destination Port        |
49/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
50/// |            Length             |           Checksum            |
51/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
52/// |                             Data                              |
53/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
54/// ```
55#[derive(Debug, Clone)]
56pub struct UdpLayer {
57    pub index: LayerIndex,
58}
59
60impl UdpLayer {
61    /// Create a new UDP layer from a layer index.
62    #[must_use]
63    pub fn new(index: LayerIndex) -> Self {
64        Self { index }
65    }
66
67    /// Get the source port.
68    pub fn src_port(&self, buf: &[u8]) -> Result<u16, FieldError> {
69        let slice = self.index.slice(buf);
70        if slice.len() < offsets::SRC_PORT + 2 {
71            return Err(FieldError::BufferTooShort {
72                offset: self.index.start + offsets::SRC_PORT,
73                need: 2,
74                have: slice.len().saturating_sub(offsets::SRC_PORT),
75            });
76        }
77        Ok(u16::from_be_bytes([
78            slice[offsets::SRC_PORT],
79            slice[offsets::SRC_PORT + 1],
80        ]))
81    }
82
83    /// Get the destination port.
84    pub fn dst_port(&self, buf: &[u8]) -> Result<u16, FieldError> {
85        let slice = self.index.slice(buf);
86        if slice.len() < offsets::DST_PORT + 2 {
87            return Err(FieldError::BufferTooShort {
88                offset: self.index.start + offsets::DST_PORT,
89                need: 2,
90                have: slice.len().saturating_sub(offsets::DST_PORT),
91            });
92        }
93        Ok(u16::from_be_bytes([
94            slice[offsets::DST_PORT],
95            slice[offsets::DST_PORT + 1],
96        ]))
97    }
98
99    /// Get the length field (header + data length in bytes).
100    pub fn length(&self, buf: &[u8]) -> Result<u16, FieldError> {
101        let slice = self.index.slice(buf);
102        if slice.len() < offsets::LENGTH + 2 {
103            return Err(FieldError::BufferTooShort {
104                offset: self.index.start + offsets::LENGTH,
105                need: 2,
106                have: slice.len().saturating_sub(offsets::LENGTH),
107            });
108        }
109        Ok(u16::from_be_bytes([
110            slice[offsets::LENGTH],
111            slice[offsets::LENGTH + 1],
112        ]))
113    }
114
115    /// Get the checksum.
116    pub fn checksum(&self, buf: &[u8]) -> Result<u16, FieldError> {
117        let slice = self.index.slice(buf);
118        if slice.len() < offsets::CHECKSUM + 2 {
119            return Err(FieldError::BufferTooShort {
120                offset: self.index.start + offsets::CHECKSUM,
121                need: 2,
122                have: slice.len().saturating_sub(offsets::CHECKSUM),
123            });
124        }
125        Ok(u16::from_be_bytes([
126            slice[offsets::CHECKSUM],
127            slice[offsets::CHECKSUM + 1],
128        ]))
129    }
130
131    /// Set the source port.
132    pub fn set_src_port(&self, buf: &mut [u8], value: u16) -> Result<(), FieldError> {
133        let start = self.index.start + offsets::SRC_PORT;
134        if buf.len() < start + 2 {
135            return Err(FieldError::BufferTooShort {
136                offset: start,
137                need: 2,
138                have: buf.len().saturating_sub(start),
139            });
140        }
141        buf[start..start + 2].copy_from_slice(&value.to_be_bytes());
142        Ok(())
143    }
144
145    /// Set the destination port.
146    pub fn set_dst_port(&self, buf: &mut [u8], value: u16) -> Result<(), FieldError> {
147        let start = self.index.start + offsets::DST_PORT;
148        if buf.len() < start + 2 {
149            return Err(FieldError::BufferTooShort {
150                offset: start,
151                need: 2,
152                have: buf.len().saturating_sub(start),
153            });
154        }
155        buf[start..start + 2].copy_from_slice(&value.to_be_bytes());
156        Ok(())
157    }
158
159    /// Set the length field.
160    pub fn set_length(&self, buf: &mut [u8], value: u16) -> Result<(), FieldError> {
161        let start = self.index.start + offsets::LENGTH;
162        if buf.len() < start + 2 {
163            return Err(FieldError::BufferTooShort {
164                offset: start,
165                need: 2,
166                have: buf.len().saturating_sub(start),
167            });
168        }
169        buf[start..start + 2].copy_from_slice(&value.to_be_bytes());
170        Ok(())
171    }
172
173    /// Set the checksum.
174    pub fn set_checksum(&self, buf: &mut [u8], value: u16) -> Result<(), FieldError> {
175        let start = self.index.start + offsets::CHECKSUM;
176        if buf.len() < start + 2 {
177            return Err(FieldError::BufferTooShort {
178                offset: start,
179                need: 2,
180                have: buf.len().saturating_sub(start),
181            });
182        }
183        buf[start..start + 2].copy_from_slice(&value.to_be_bytes());
184        Ok(())
185    }
186
187    /// Generate a summary string for display.
188    #[must_use]
189    pub fn summary(&self, buf: &[u8]) -> String {
190        let slice = self.index.slice(buf);
191        if slice.len() >= 4 {
192            let src_port = u16::from_be_bytes([slice[0], slice[1]]);
193            let dst_port = u16::from_be_bytes([slice[2], slice[3]]);
194            format!("UDP {src_port} > {dst_port}")
195        } else {
196            "UDP".to_string()
197        }
198    }
199
200    /// Get the UDP header length (always 8 bytes).
201    #[must_use]
202    pub fn header_len(&self, _buf: &[u8]) -> usize {
203        UDP_HEADER_LEN
204    }
205
206    /// Get field names for this layer.
207    #[must_use]
208    pub fn field_names(&self) -> &'static [&'static str] {
209        &["sport", "dport", "len", "chksum"]
210    }
211
212    /// Get a field value by name.
213    pub fn get_field(&self, buf: &[u8], name: &str) -> Option<Result<FieldValue, FieldError>> {
214        match name {
215            "sport" => Some(self.src_port(buf).map(FieldValue::U16)),
216            "dport" => Some(self.dst_port(buf).map(FieldValue::U16)),
217            "len" => Some(self.length(buf).map(FieldValue::U16)),
218            "chksum" => Some(self.checksum(buf).map(FieldValue::U16)),
219            _ => None,
220        }
221    }
222
223    /// Set a field value by name.
224    pub fn set_field(
225        &self,
226        buf: &mut [u8],
227        name: &str,
228        value: FieldValue,
229    ) -> Option<Result<(), FieldError>> {
230        match name {
231            "sport" => {
232                if let FieldValue::U16(v) = value {
233                    Some(self.set_src_port(buf, v))
234                } else {
235                    Some(Err(FieldError::InvalidValue(format!(
236                        "sport: expected U16, got {value:?}"
237                    ))))
238                }
239            },
240            "dport" => {
241                if let FieldValue::U16(v) = value {
242                    Some(self.set_dst_port(buf, v))
243                } else {
244                    Some(Err(FieldError::InvalidValue(format!(
245                        "dport: expected U16, got {value:?}"
246                    ))))
247                }
248            },
249            "len" => {
250                if let FieldValue::U16(v) = value {
251                    Some(self.set_length(buf, v))
252                } else {
253                    Some(Err(FieldError::InvalidValue(format!(
254                        "len: expected U16, got {value:?}"
255                    ))))
256                }
257            },
258            "chksum" => {
259                if let FieldValue::U16(v) = value {
260                    Some(self.set_checksum(buf, v))
261                } else {
262                    Some(Err(FieldError::InvalidValue(format!(
263                        "chksum: expected U16, got {value:?}"
264                    ))))
265                }
266            },
267            _ => None,
268        }
269    }
270}
271
272impl Layer for UdpLayer {
273    fn kind(&self) -> LayerKind {
274        LayerKind::Udp
275    }
276
277    fn summary(&self, data: &[u8]) -> String {
278        self.summary(data)
279    }
280
281    fn header_len(&self, data: &[u8]) -> usize {
282        self.header_len(data)
283    }
284
285    fn hashret(&self, _data: &[u8]) -> Vec<u8> {
286        // UDP is stateless, return empty hash
287        vec![]
288    }
289
290    fn answers(&self, data: &[u8], other: &Self, other_data: &[u8]) -> bool {
291        // UDP answers if destination port matches other's source port
292        if let (Ok(self_dport), Ok(other_sport)) = (self.dst_port(data), other.src_port(other_data))
293        {
294            self_dport == other_sport
295        } else {
296            false
297        }
298    }
299
300    fn extract_padding<'a>(&self, data: &'a [u8]) -> (&'a [u8], &'a [u8]) {
301        // UDP length field includes header + payload
302        if let Ok(udp_len) = self.length(data) {
303            let udp_len = udp_len as usize;
304            let start = self.index.start;
305            let available = data.len().saturating_sub(start);
306
307            if udp_len >= UDP_HEADER_LEN && udp_len <= available {
308                let end = start + udp_len;
309                return (&data[start..end], &data[end..]);
310            }
311        }
312        // If length field is invalid, assume no padding
313        let payload = self.index.payload(data);
314        (
315            &data[self.index.start..self.index.start + self.index.len() + payload.len()],
316            &[],
317        )
318    }
319
320    fn field_names(&self) -> &'static [&'static str] {
321        self.field_names()
322    }
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_udp_parse() {
        // UDP packet: sport=12345, dport=53, len=20, checksum=0x1234
        let data = [
            0x30, 0x39, // sport = 12345
            0x00, 0x35, // dport = 53 (DNS)
            0x00, 0x14, // len = 20
            0x12, 0x34, // checksum
            // 12 bytes of payload
            0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
        ];

        let index = LayerIndex::new(LayerKind::Udp, 0, 8);
        let udp = UdpLayer::new(index);

        assert_eq!(udp.src_port(&data).unwrap(), 12345);
        assert_eq!(udp.dst_port(&data).unwrap(), 53);
        assert_eq!(udp.length(&data).unwrap(), 20);
        assert_eq!(udp.checksum(&data).unwrap(), 0x1234);
    }

    #[test]
    fn test_udp_summary() {
        let data = [
            0x04, 0xd2, // sport = 1234
            0x00, 0x50, // dport = 80 (HTTP)
            0x00, 0x08, // len = 8 (header only)
            0x00, 0x00, // checksum
        ];

        let index = LayerIndex::new(LayerKind::Udp, 0, 8);
        let udp = UdpLayer::new(index);

        let summary = udp.summary(&data);
        assert_eq!(summary, "UDP 1234 > 80");
    }

    #[test]
    fn test_udp_set_fields() {
        let mut data = vec![0u8; 8];
        let index = LayerIndex::new(LayerKind::Udp, 0, 8);
        let udp = UdpLayer::new(index);

        udp.set_src_port(&mut data, 5000).unwrap();
        udp.set_dst_port(&mut data, 6000).unwrap();
        udp.set_length(&mut data, 100).unwrap();
        udp.set_checksum(&mut data, 0xABCD).unwrap();

        assert_eq!(udp.src_port(&data).unwrap(), 5000);
        assert_eq!(udp.dst_port(&data).unwrap(), 6000);
        assert_eq!(udp.length(&data).unwrap(), 100);
        assert_eq!(udp.checksum(&data).unwrap(), 0xABCD);
    }

    #[test]
    fn test_udp_extract_padding() {
        // UDP packet with padding
        let data = [
            0x30, 0x39, // sport
            0x00, 0x35, // dport
            0x00, 0x0c, // len = 12 (8 header + 4 payload)
            0x00, 0x00, // checksum
            0x01, 0x02, 0x03, 0x04, // 4 bytes payload
            0xff, 0xff, 0xff, 0xff, // 4 bytes padding
        ];

        let index = LayerIndex::new(LayerKind::Udp, 0, 8);
        let udp = UdpLayer::new(index);

        let (udp_data, padding) = udp.extract_padding(&data);
        assert_eq!(udp_data.len(), 12); // 8 header + 4 payload
        assert_eq!(padding.len(), 4); // 4 bytes padding
    }

    #[test]
    fn test_udp_truncated_header() {
        // Only 3 header bytes present: sport is readable, dport is not.
        let data = [0x30, 0x39, 0x00];
        let index = LayerIndex::new(LayerKind::Udp, 0, 3);
        let udp = UdpLayer::new(index);

        assert_eq!(udp.src_port(&data).unwrap(), 12345);
        assert!(matches!(
            udp.dst_port(&data),
            Err(FieldError::BufferTooShort {
                offset: 2,
                need: 2,
                have: 1
            })
        ));
        assert_eq!(udp.summary(&data), "UDP");
    }

    #[test]
    fn test_udp_setter_buffer_too_short() {
        let mut data = vec![0u8; 7];
        let index = LayerIndex::new(LayerKind::Udp, 0, 8);
        let udp = UdpLayer::new(index);

        udp.set_src_port(&mut data, 1).unwrap();
        assert!(matches!(
            udp.set_checksum(&mut data, 0xFFFF),
            Err(FieldError::BufferTooShort {
                offset: 6,
                need: 2,
                have: 1
            })
        ));
        // A failed write must leave the buffer untouched.
        assert_eq!(data, [0u8, 1, 0, 0, 0, 0, 0]);
    }

    #[test]
    fn test_udp_dynamic_fields() {
        let mut data = vec![0u8; 8];
        let index = LayerIndex::new(LayerKind::Udp, 0, 8);
        let udp = UdpLayer::new(index);

        assert_eq!(udp.field_names(), &["sport", "dport", "len", "chksum"]);
        assert!(matches!(
            udp.set_field(&mut data, "sport", FieldValue::U16(1234)),
            Some(Ok(()))
        ));
        assert!(matches!(
            udp.get_field(&data, "sport"),
            Some(Ok(FieldValue::U16(1234)))
        ));
        // Unknown names are not UDP fields.
        assert!(udp.get_field(&data, "ttl").is_none());
        assert!(udp.set_field(&mut data, "ttl", FieldValue::U16(1)).is_none());
    }

    #[test]
    fn test_udp_answers() {
        let request = [0x30, 0x39, 0x00, 0x35, 0x00, 0x08, 0x00, 0x00]; // 12345 -> 53
        let reply = [0x00, 0x35, 0x30, 0x39, 0x00, 0x08, 0x00, 0x00]; // 53 -> 12345
        let udp = UdpLayer::new(LayerIndex::new(LayerKind::Udp, 0, 8));

        assert!(udp.answers(&reply, &udp, &request));
        assert!(!udp.answers(&request, &udp, &request));
        assert!(udp.hashret(&request).is_empty());
    }

    #[test]
    fn test_udp_header_len() {
        let udp = UdpLayer::new(LayerIndex::new(LayerKind::Udp, 0, 8));
        assert_eq!(udp.header_len(&[]), UDP_HEADER_LEN);
    }
}