Skip to main content

stackforge_core/layer/udp/
mod.rs

1//! UDP (User Datagram Protocol) layer implementation.
2//!
3//! This module provides types and functions for working with UDP packets,
4//! including parsing, field access, and checksum calculation.
5
6pub mod builder;
7pub mod checksum;
8
9// Re-export builder
10pub use builder::UdpBuilder;
11
12// Re-export checksum functions
13pub use checksum::{
14    udp_checksum_ipv4, udp_checksum_ipv6, verify_udp_checksum_ipv4, verify_udp_checksum_ipv6,
15};
16
17use crate::layer::field::{FieldDesc, FieldError, FieldType, FieldValue};
18use crate::layer::{Layer, LayerIndex, LayerKind};
19
/// Length of the fixed UDP header in bytes.
///
/// The header is exactly four 16-bit fields (source port, destination port,
/// length, checksum), i.e. 8 bytes; UDP has no options.
pub const UDP_HEADER_LEN: usize = 4 * std::mem::size_of::<u16>();
22
/// Byte offsets of each field, relative to the start of the UDP header.
///
/// Every field is a 16-bit word, so each offset is the previous one plus two.
pub mod offsets {
    /// Source port (header bytes 0..2).
    pub const SRC_PORT: usize = 0;
    /// Destination port (header bytes 2..4).
    pub const DST_PORT: usize = SRC_PORT + 2;
    /// Length of header plus data, in bytes (header bytes 4..6).
    pub const LENGTH: usize = DST_PORT + 2;
    /// Checksum (header bytes 6..8).
    pub const CHECKSUM: usize = LENGTH + 2;
}
30
/// UDP field descriptors for dynamic access.
///
/// One entry per header field, in wire order. Each entry gives the field name
/// (the same names `UdpLayer::field_names` returns and `get_field`/`set_field`
/// accept), its byte offset within the UDP header (see [`offsets`]), the value
/// `2` (presumably the field width in bytes — TODO confirm against
/// `FieldDesc::new`), and the value type.
pub static FIELDS: &[FieldDesc] = &[
    FieldDesc::new("sport", offsets::SRC_PORT, 2, FieldType::U16),
    FieldDesc::new("dport", offsets::DST_PORT, 2, FieldType::U16),
    FieldDesc::new("len", offsets::LENGTH, 2, FieldType::U16),
    FieldDesc::new("chksum", offsets::CHECKSUM, 2, FieldType::U16),
];
38
/// UDP layer representation.
///
/// UDP is a simple, connectionless transport protocol defined in RFC 768.
/// The header is 8 bytes:
///
/// ```text
///  0                   1                   2                   3
///  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |          Source Port          |       Destination Port        |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |            Length             |           Checksum            |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |                             Data                              |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// ```
///
/// A `UdpLayer` does not hold packet bytes itself. Every accessor takes the
/// packet buffer as an argument and finds the header through `index`.
#[derive(Debug, Clone)]
pub struct UdpLayer {
    /// Location of this layer within the packet buffer. Accessors use its
    /// `start` offset and `slice()` to reach the header bytes.
    pub index: LayerIndex,
}
59
60impl UdpLayer {
61    /// Create a new UDP layer from a layer index.
62    pub fn new(index: LayerIndex) -> Self {
63        Self { index }
64    }
65
66    /// Get the source port.
67    pub fn src_port(&self, buf: &[u8]) -> Result<u16, FieldError> {
68        let slice = self.index.slice(buf);
69        if slice.len() < offsets::SRC_PORT + 2 {
70            return Err(FieldError::BufferTooShort {
71                offset: self.index.start + offsets::SRC_PORT,
72                need: 2,
73                have: slice.len().saturating_sub(offsets::SRC_PORT),
74            });
75        }
76        Ok(u16::from_be_bytes([
77            slice[offsets::SRC_PORT],
78            slice[offsets::SRC_PORT + 1],
79        ]))
80    }
81
82    /// Get the destination port.
83    pub fn dst_port(&self, buf: &[u8]) -> Result<u16, FieldError> {
84        let slice = self.index.slice(buf);
85        if slice.len() < offsets::DST_PORT + 2 {
86            return Err(FieldError::BufferTooShort {
87                offset: self.index.start + offsets::DST_PORT,
88                need: 2,
89                have: slice.len().saturating_sub(offsets::DST_PORT),
90            });
91        }
92        Ok(u16::from_be_bytes([
93            slice[offsets::DST_PORT],
94            slice[offsets::DST_PORT + 1],
95        ]))
96    }
97
98    /// Get the length field (header + data length in bytes).
99    pub fn length(&self, buf: &[u8]) -> Result<u16, FieldError> {
100        let slice = self.index.slice(buf);
101        if slice.len() < offsets::LENGTH + 2 {
102            return Err(FieldError::BufferTooShort {
103                offset: self.index.start + offsets::LENGTH,
104                need: 2,
105                have: slice.len().saturating_sub(offsets::LENGTH),
106            });
107        }
108        Ok(u16::from_be_bytes([
109            slice[offsets::LENGTH],
110            slice[offsets::LENGTH + 1],
111        ]))
112    }
113
114    /// Get the checksum.
115    pub fn checksum(&self, buf: &[u8]) -> Result<u16, FieldError> {
116        let slice = self.index.slice(buf);
117        if slice.len() < offsets::CHECKSUM + 2 {
118            return Err(FieldError::BufferTooShort {
119                offset: self.index.start + offsets::CHECKSUM,
120                need: 2,
121                have: slice.len().saturating_sub(offsets::CHECKSUM),
122            });
123        }
124        Ok(u16::from_be_bytes([
125            slice[offsets::CHECKSUM],
126            slice[offsets::CHECKSUM + 1],
127        ]))
128    }
129
130    /// Set the source port.
131    pub fn set_src_port(&self, buf: &mut [u8], value: u16) -> Result<(), FieldError> {
132        let start = self.index.start + offsets::SRC_PORT;
133        if buf.len() < start + 2 {
134            return Err(FieldError::BufferTooShort {
135                offset: start,
136                need: 2,
137                have: buf.len().saturating_sub(start),
138            });
139        }
140        buf[start..start + 2].copy_from_slice(&value.to_be_bytes());
141        Ok(())
142    }
143
144    /// Set the destination port.
145    pub fn set_dst_port(&self, buf: &mut [u8], value: u16) -> Result<(), FieldError> {
146        let start = self.index.start + offsets::DST_PORT;
147        if buf.len() < start + 2 {
148            return Err(FieldError::BufferTooShort {
149                offset: start,
150                need: 2,
151                have: buf.len().saturating_sub(start),
152            });
153        }
154        buf[start..start + 2].copy_from_slice(&value.to_be_bytes());
155        Ok(())
156    }
157
158    /// Set the length field.
159    pub fn set_length(&self, buf: &mut [u8], value: u16) -> Result<(), FieldError> {
160        let start = self.index.start + offsets::LENGTH;
161        if buf.len() < start + 2 {
162            return Err(FieldError::BufferTooShort {
163                offset: start,
164                need: 2,
165                have: buf.len().saturating_sub(start),
166            });
167        }
168        buf[start..start + 2].copy_from_slice(&value.to_be_bytes());
169        Ok(())
170    }
171
172    /// Set the checksum.
173    pub fn set_checksum(&self, buf: &mut [u8], value: u16) -> Result<(), FieldError> {
174        let start = self.index.start + offsets::CHECKSUM;
175        if buf.len() < start + 2 {
176            return Err(FieldError::BufferTooShort {
177                offset: start,
178                need: 2,
179                have: buf.len().saturating_sub(start),
180            });
181        }
182        buf[start..start + 2].copy_from_slice(&value.to_be_bytes());
183        Ok(())
184    }
185
186    /// Generate a summary string for display.
187    pub fn summary(&self, buf: &[u8]) -> String {
188        let slice = self.index.slice(buf);
189        if slice.len() >= 4 {
190            let src_port = u16::from_be_bytes([slice[0], slice[1]]);
191            let dst_port = u16::from_be_bytes([slice[2], slice[3]]);
192            format!("UDP {} > {}", src_port, dst_port)
193        } else {
194            "UDP".to_string()
195        }
196    }
197
198    /// Get the UDP header length (always 8 bytes).
199    pub fn header_len(&self, _buf: &[u8]) -> usize {
200        UDP_HEADER_LEN
201    }
202
203    /// Get field names for this layer.
204    pub fn field_names(&self) -> &'static [&'static str] {
205        &["sport", "dport", "len", "chksum"]
206    }
207
208    /// Get a field value by name.
209    pub fn get_field(&self, buf: &[u8], name: &str) -> Option<Result<FieldValue, FieldError>> {
210        match name {
211            "sport" => Some(self.src_port(buf).map(FieldValue::U16)),
212            "dport" => Some(self.dst_port(buf).map(FieldValue::U16)),
213            "len" => Some(self.length(buf).map(FieldValue::U16)),
214            "chksum" => Some(self.checksum(buf).map(FieldValue::U16)),
215            _ => None,
216        }
217    }
218
219    /// Set a field value by name.
220    pub fn set_field(
221        &self,
222        buf: &mut [u8],
223        name: &str,
224        value: FieldValue,
225    ) -> Option<Result<(), FieldError>> {
226        match name {
227            "sport" => {
228                if let FieldValue::U16(v) = value {
229                    Some(self.set_src_port(buf, v))
230                } else {
231                    Some(Err(FieldError::InvalidValue(format!(
232                        "sport: expected U16, got {:?}",
233                        value
234                    ))))
235                }
236            }
237            "dport" => {
238                if let FieldValue::U16(v) = value {
239                    Some(self.set_dst_port(buf, v))
240                } else {
241                    Some(Err(FieldError::InvalidValue(format!(
242                        "dport: expected U16, got {:?}",
243                        value
244                    ))))
245                }
246            }
247            "len" => {
248                if let FieldValue::U16(v) = value {
249                    Some(self.set_length(buf, v))
250                } else {
251                    Some(Err(FieldError::InvalidValue(format!(
252                        "len: expected U16, got {:?}",
253                        value
254                    ))))
255                }
256            }
257            "chksum" => {
258                if let FieldValue::U16(v) = value {
259                    Some(self.set_checksum(buf, v))
260                } else {
261                    Some(Err(FieldError::InvalidValue(format!(
262                        "chksum: expected U16, got {:?}",
263                        value
264                    ))))
265                }
266            }
267            _ => None,
268        }
269    }
270}
271
272impl Layer for UdpLayer {
273    fn kind(&self) -> LayerKind {
274        LayerKind::Udp
275    }
276
277    fn summary(&self, data: &[u8]) -> String {
278        self.summary(data)
279    }
280
281    fn header_len(&self, data: &[u8]) -> usize {
282        self.header_len(data)
283    }
284
285    fn hashret(&self, _data: &[u8]) -> Vec<u8> {
286        // UDP is stateless, return empty hash
287        vec![]
288    }
289
290    fn answers(&self, data: &[u8], other: &Self, other_data: &[u8]) -> bool {
291        // UDP answers if destination port matches other's source port
292        if let (Ok(self_dport), Ok(other_sport)) = (self.dst_port(data), other.src_port(other_data))
293        {
294            self_dport == other_sport
295        } else {
296            false
297        }
298    }
299
300    fn extract_padding<'a>(&self, data: &'a [u8]) -> (&'a [u8], &'a [u8]) {
301        // UDP length field includes header + payload
302        if let Ok(udp_len) = self.length(data) {
303            let udp_len = udp_len as usize;
304            let start = self.index.start;
305            let available = data.len().saturating_sub(start);
306
307            if udp_len >= UDP_HEADER_LEN && udp_len <= available {
308                let end = start + udp_len;
309                return (&data[start..end], &data[end..]);
310            }
311        }
312        // If length field is invalid, assume no padding
313        let payload = self.index.payload(data);
314        (
315            &data[self.index.start..self.index.start + self.index.len() + payload.len()],
316            &[],
317        )
318    }
319
320    fn field_names(&self) -> &'static [&'static str] {
321        self.field_names()
322    }
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_udp_parse() {
        // UDP packet: sport=12345, dport=53, len=20, checksum=0x1234
        let data = [
            0x30, 0x39, // sport = 12345
            0x00, 0x35, // dport = 53 (DNS)
            0x00, 0x14, // len = 20
            0x12, 0x34, // checksum
            // 12 bytes of payload
            0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
        ];

        let index = LayerIndex::new(LayerKind::Udp, 0, 8);
        let udp = UdpLayer::new(index);

        assert_eq!(udp.src_port(&data).unwrap(), 12345);
        assert_eq!(udp.dst_port(&data).unwrap(), 53);
        assert_eq!(udp.length(&data).unwrap(), 20);
        assert_eq!(udp.checksum(&data).unwrap(), 0x1234);
    }

    #[test]
    fn test_udp_summary() {
        let data = [
            0x04, 0xd2, // sport = 1234
            0x00, 0x50, // dport = 80 (HTTP)
            0x00, 0x08, // len = 8 (header only)
            0x00, 0x00, // checksum
        ];

        let index = LayerIndex::new(LayerKind::Udp, 0, 8);
        let udp = UdpLayer::new(index);

        let summary = udp.summary(&data);
        assert_eq!(summary, "UDP 1234 > 80");
    }

    #[test]
    fn test_udp_set_fields() {
        let mut data = vec![0u8; 8];
        let index = LayerIndex::new(LayerKind::Udp, 0, 8);
        let udp = UdpLayer::new(index);

        udp.set_src_port(&mut data, 5000).unwrap();
        udp.set_dst_port(&mut data, 6000).unwrap();
        udp.set_length(&mut data, 100).unwrap();
        udp.set_checksum(&mut data, 0xABCD).unwrap();

        assert_eq!(udp.src_port(&data).unwrap(), 5000);
        assert_eq!(udp.dst_port(&data).unwrap(), 6000);
        assert_eq!(udp.length(&data).unwrap(), 100);
        assert_eq!(udp.checksum(&data).unwrap(), 0xABCD);
    }

    #[test]
    fn test_udp_extract_padding() {
        // UDP packet with padding
        let data = [
            0x30, 0x39, // sport
            0x00, 0x35, // dport
            0x00, 0x0c, // len = 12 (8 header + 4 payload)
            0x00, 0x00, // checksum
            0x01, 0x02, 0x03, 0x04, // 4 bytes payload
            0xff, 0xff, 0xff, 0xff, // 4 bytes padding
        ];

        let index = LayerIndex::new(LayerKind::Udp, 0, 8);
        let udp = UdpLayer::new(index);

        let (udp_data, padding) = udp.extract_padding(&data);
        assert_eq!(udp_data.len(), 12); // 8 header + 4 payload
        assert_eq!(padding.len(), 4); // 4 bytes padding
    }

    #[test]
    fn test_udp_get_field_by_name() {
        let data: [u8; 8] = [0x30, 0x39, 0x00, 0x35, 0x00, 0x08, 0xbe, 0xef];
        let udp = UdpLayer::new(LayerIndex::new(LayerKind::Udp, 0, 8));

        assert!(matches!(
            udp.get_field(&data, "sport"),
            Some(Ok(FieldValue::U16(12345)))
        ));
        assert!(matches!(
            udp.get_field(&data, "dport"),
            Some(Ok(FieldValue::U16(53)))
        ));
        assert!(matches!(
            udp.get_field(&data, "len"),
            Some(Ok(FieldValue::U16(8)))
        ));
        assert!(matches!(
            udp.get_field(&data, "chksum"),
            Some(Ok(FieldValue::U16(0xbeef)))
        ));
        // Unknown names are not an error, just absent.
        assert!(udp.get_field(&data, "nope").is_none());
    }

    #[test]
    fn test_udp_set_field_by_name() {
        let mut data = vec![0u8; 8];
        let udp = UdpLayer::new(LayerIndex::new(LayerKind::Udp, 0, 8));

        assert!(matches!(
            udp.set_field(&mut data, "sport", FieldValue::U16(4000)),
            Some(Ok(()))
        ));
        assert!(matches!(
            udp.set_field(&mut data, "dport", FieldValue::U16(4001)),
            Some(Ok(()))
        ));
        assert!(matches!(
            udp.set_field(&mut data, "len", FieldValue::U16(8)),
            Some(Ok(()))
        ));
        assert!(matches!(
            udp.set_field(&mut data, "chksum", FieldValue::U16(0x1111)),
            Some(Ok(()))
        ));
        assert!(udp.set_field(&mut data, "nope", FieldValue::U16(1)).is_none());

        assert_eq!(udp.src_port(&data).unwrap(), 4000);
        assert_eq!(udp.dst_port(&data).unwrap(), 4001);
        assert_eq!(udp.length(&data).unwrap(), 8);
        assert_eq!(udp.checksum(&data).unwrap(), 0x1111);
    }

    #[test]
    fn test_udp_answers() {
        let request: [u8; 8] = [0x30, 0x39, 0x00, 0x35, 0x00, 0x08, 0x00, 0x00]; // 12345 -> 53
        let reply: [u8; 8] = [0x00, 0x35, 0x30, 0x39, 0x00, 0x08, 0x00, 0x00]; // 53 -> 12345
        let stray: [u8; 8] = [0x00, 0x35, 0x04, 0xd2, 0x00, 0x08, 0x00, 0x00]; // 53 -> 1234
        let udp = UdpLayer::new(LayerIndex::new(LayerKind::Udp, 0, 8));

        assert!(udp.answers(&reply, &udp, &request));
        assert!(!udp.answers(&stray, &udp, &request));
    }

    #[test]
    fn test_udp_header_len_and_field_names() {
        let data = [0u8; 8];
        let udp = UdpLayer::new(LayerIndex::new(LayerKind::Udp, 0, 8));

        assert_eq!(udp.header_len(&data), UDP_HEADER_LEN);
        assert_eq!(udp.field_names(), &["sport", "dport", "len", "chksum"]);
    }
}