Skip to main content

stackforge_core/layer/dns/
svcb.rs

1//! SVCB/HTTPS Service Parameters (RFC 9460).
2
3use std::net::{Ipv4Addr, Ipv6Addr};
4
5use crate::layer::field::FieldError;
6
/// SvcParamKey code points used by SVCB/HTTPS records (RFC 9460).
///
/// Each constant is the 16-bit key that precedes a parameter value on the wire.
pub mod svc_key {
    /// `mandatory`: list of keys a client must understand.
    pub const MANDATORY: u16 = 0x0000;
    /// `alpn`: supported ALPN protocol identifiers.
    pub const ALPN: u16 = 0x0001;
    /// `no-default-alpn`: flag parameter with an empty value.
    pub const NO_DEFAULT_ALPN: u16 = 0x0002;
    /// `port`: alternative port number.
    pub const PORT: u16 = 0x0003;
    /// `ipv4hint`: IPv4 address hints.
    pub const IPV4HINT: u16 = 0x0004;
    /// `ech`: Encrypted Client Hello configuration.
    pub const ECH: u16 = 0x0005;
    /// `ipv6hint`: IPv6 address hints.
    pub const IPV6HINT: u16 = 0x0006;
}
17
/// A single service parameter (SvcParam) from an SVCB/HTTPS record.
///
/// Each variant corresponds to one key in [`svc_key`]. Keys without a dedicated
/// variant are preserved verbatim as [`SvcParam::Unknown`].
///
/// All payloads are plain owned data, so the type is `Eq` and `Hash` as well as
/// `PartialEq`. Parameters can therefore be deduplicated or used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SvcParam {
    /// `mandatory` (key 0): SvcParamKeys that must be understood.
    Mandatory(Vec<u16>),
    /// `alpn` (key 1): Application-Layer Protocol Negotiation identifiers.
    Alpn(Vec<String>),
    /// `no-default-alpn` (key 2): flag parameter; carries no value.
    NoDefaultAlpn,
    /// `port` (key 3): port number.
    Port(u16),
    /// `ipv4hint` (key 4): IPv4 address hints.
    Ipv4Hint(Vec<Ipv4Addr>),
    /// `ech` (key 5): Encrypted Client Hello value, kept as opaque bytes.
    Ech(Vec<u8>),
    /// `ipv6hint` (key 6): IPv6 address hints.
    Ipv6Hint(Vec<Ipv6Addr>),
    /// Parameter whose key has no dedicated variant, with its raw value bytes.
    Unknown { key: u16, value: Vec<u8> },
}
38
39impl SvcParam {
40    /// Parse a single SvcParam from wire format.
41    pub fn parse(key: u16, data: &[u8]) -> Result<Self, FieldError> {
42        match key {
43            svc_key::MANDATORY => {
44                if data.len() % 2 != 0 {
45                    return Err(FieldError::InvalidValue(
46                        "mandatory param length must be even".to_string(),
47                    ));
48                }
49                let keys: Vec<u16> = data
50                    .chunks_exact(2)
51                    .map(|c| u16::from_be_bytes([c[0], c[1]]))
52                    .collect();
53                Ok(SvcParam::Mandatory(keys))
54            },
55
56            svc_key::ALPN => {
57                let mut alpns = Vec::new();
58                let mut pos = 0;
59                while pos < data.len() {
60                    let len = data[pos] as usize;
61                    pos += 1;
62                    if pos + len > data.len() {
63                        return Err(FieldError::BufferTooShort {
64                            offset: pos,
65                            need: len,
66                            have: data.len() - pos,
67                        });
68                    }
69                    alpns.push(String::from_utf8_lossy(&data[pos..pos + len]).into_owned());
70                    pos += len;
71                }
72                Ok(SvcParam::Alpn(alpns))
73            },
74
75            svc_key::NO_DEFAULT_ALPN => Ok(SvcParam::NoDefaultAlpn),
76
77            svc_key::PORT => {
78                if data.len() != 2 {
79                    return Err(FieldError::InvalidValue(
80                        "port param must be 2 bytes".to_string(),
81                    ));
82                }
83                Ok(SvcParam::Port(u16::from_be_bytes([data[0], data[1]])))
84            },
85
86            svc_key::IPV4HINT => {
87                if data.len() % 4 != 0 {
88                    return Err(FieldError::InvalidValue(
89                        "ipv4hint length must be multiple of 4".to_string(),
90                    ));
91                }
92                let addrs: Vec<Ipv4Addr> = data
93                    .chunks_exact(4)
94                    .map(|c| Ipv4Addr::new(c[0], c[1], c[2], c[3]))
95                    .collect();
96                Ok(SvcParam::Ipv4Hint(addrs))
97            },
98
99            svc_key::ECH => Ok(SvcParam::Ech(data.to_vec())),
100
101            svc_key::IPV6HINT => {
102                if data.len() % 16 != 0 {
103                    return Err(FieldError::InvalidValue(
104                        "ipv6hint length must be multiple of 16".to_string(),
105                    ));
106                }
107                let addrs: Vec<Ipv6Addr> = data
108                    .chunks_exact(16)
109                    .map(|c| {
110                        let mut arr = [0u8; 16];
111                        arr.copy_from_slice(c);
112                        Ipv6Addr::from(arr)
113                    })
114                    .collect();
115                Ok(SvcParam::Ipv6Hint(addrs))
116            },
117
118            _ => Ok(SvcParam::Unknown {
119                key,
120                value: data.to_vec(),
121            }),
122        }
123    }
124
125    /// Get the key for this parameter.
126    pub fn key(&self) -> u16 {
127        match self {
128            SvcParam::Mandatory(_) => svc_key::MANDATORY,
129            SvcParam::Alpn(_) => svc_key::ALPN,
130            SvcParam::NoDefaultAlpn => svc_key::NO_DEFAULT_ALPN,
131            SvcParam::Port(_) => svc_key::PORT,
132            SvcParam::Ipv4Hint(_) => svc_key::IPV4HINT,
133            SvcParam::Ech(_) => svc_key::ECH,
134            SvcParam::Ipv6Hint(_) => svc_key::IPV6HINT,
135            SvcParam::Unknown { key, .. } => *key,
136        }
137    }
138
139    /// Serialize the parameter value (without key and length header).
140    pub fn build_value(&self) -> Vec<u8> {
141        match self {
142            SvcParam::Mandatory(keys) => keys.iter().flat_map(|k| k.to_be_bytes()).collect(),
143            SvcParam::Alpn(alpns) => {
144                let mut out = Vec::new();
145                for alpn in alpns {
146                    out.push(alpn.len() as u8);
147                    out.extend_from_slice(alpn.as_bytes());
148                }
149                out
150            },
151            SvcParam::NoDefaultAlpn => Vec::new(),
152            SvcParam::Port(port) => port.to_be_bytes().to_vec(),
153            SvcParam::Ipv4Hint(addrs) => addrs.iter().flat_map(|a| a.octets()).collect(),
154            SvcParam::Ech(data) => data.clone(),
155            SvcParam::Ipv6Hint(addrs) => addrs.iter().flat_map(|a| a.octets()).collect(),
156            SvcParam::Unknown { value, .. } => value.clone(),
157        }
158    }
159
160    /// Build the complete key-value pair (key + length + value).
161    pub fn build(&self) -> Vec<u8> {
162        let value = self.build_value();
163        let mut out = Vec::with_capacity(4 + value.len());
164        out.extend_from_slice(&self.key().to_be_bytes());
165        out.extend_from_slice(&(value.len() as u16).to_be_bytes());
166        out.extend_from_slice(&value);
167        out
168    }
169
170    /// Parse all SvcParams from wire format.
171    pub fn parse_all(data: &[u8]) -> Result<Vec<Self>, FieldError> {
172        let mut params = Vec::new();
173        let mut pos = 0;
174
175        while pos < data.len() {
176            if pos + 4 > data.len() {
177                return Err(FieldError::BufferTooShort {
178                    offset: pos,
179                    need: 4,
180                    have: data.len(),
181                });
182            }
183            let key = u16::from_be_bytes([data[pos], data[pos + 1]]);
184            let val_len = u16::from_be_bytes([data[pos + 2], data[pos + 3]]) as usize;
185            pos += 4;
186
187            if pos + val_len > data.len() {
188                return Err(FieldError::BufferTooShort {
189                    offset: pos,
190                    need: val_len,
191                    have: data.len() - pos,
192                });
193            }
194
195            let param = Self::parse(key, &data[pos..pos + val_len])?;
196            params.push(param);
197            pos += val_len;
198        }
199
200        Ok(params)
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode every param back-to-back, then decode the combined wire bytes.
    fn wire_roundtrip(params: &[SvcParam]) -> Vec<SvcParam> {
        let mut wire = Vec::new();
        params.iter().for_each(|p| wire.extend(p.build()));
        SvcParam::parse_all(&wire).unwrap()
    }

    #[test]
    fn test_alpn_roundtrip() {
        let alpn = SvcParam::Alpn(["h2", "h3"].iter().map(|s| s.to_string()).collect());
        let decoded = wire_roundtrip(std::slice::from_ref(&alpn));
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0], alpn);
    }

    #[test]
    fn test_port_roundtrip() {
        let port = SvcParam::Port(443);
        assert_eq!(wire_roundtrip(std::slice::from_ref(&port))[0], port);
    }

    #[test]
    fn test_ipv4hint_roundtrip() {
        let hints = SvcParam::Ipv4Hint(vec![Ipv4Addr::from([1, 2, 3, 4]), Ipv4Addr::from([5, 6, 7, 8])]);
        assert_eq!(wire_roundtrip(std::slice::from_ref(&hints))[0], hints);
    }

    #[test]
    fn test_ipv6hint_roundtrip() {
        let hints = SvcParam::Ipv6Hint(vec![Ipv6Addr::LOCALHOST]);
        assert_eq!(wire_roundtrip(std::slice::from_ref(&hints))[0], hints);
    }

    #[test]
    fn test_mandatory_roundtrip() {
        let mandatory = SvcParam::Mandatory(vec![svc_key::ALPN, svc_key::PORT]);
        assert_eq!(wire_roundtrip(std::slice::from_ref(&mandatory))[0], mandatory);
    }

    #[test]
    fn test_no_default_alpn() {
        let flag = SvcParam::NoDefaultAlpn;
        // key(2) + len(2) + value(0)
        assert_eq!(flag.build().len(), 4);
        assert_eq!(wire_roundtrip(std::slice::from_ref(&flag))[0], flag);
    }

    #[test]
    fn test_multiple_params() {
        let expected = vec![
            SvcParam::Alpn(vec!["h2".to_string()]),
            SvcParam::Port(443),
            SvcParam::Ipv4Hint(vec![Ipv4Addr::from([1, 1, 1, 1])]),
        ];
        assert_eq!(wire_roundtrip(&expected), expected);
    }
}