Skip to main content

stackforge_core/layer/dns/
svcb.rs

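//! SVCB/HTTPS Service Parameters (RFC 9460).
//!
//! Wire-format parsing and serialization of individual SvcParams and of
//! concatenated SvcParam key/length/value lists.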
2
3use std::net::{Ipv4Addr, Ipv6Addr};
4
5use crate::layer::field::FieldError;
6
/// Numeric SvcParamKey values for the SVCB/HTTPS parameters this module
/// decodes into dedicated [`SvcParam`] variants (RFC 9460).
pub mod svc_key {
    /// `mandatory`: keys that must be understood to use the record.
    pub const MANDATORY: u16 = 0x0000;
    /// `alpn`: Application-Layer Protocol Negotiation identifiers.
    pub const ALPN: u16 = 0x0001;
    /// `no-default-alpn`: flag parameter carrying an empty value.
    pub const NO_DEFAULT_ALPN: u16 = 0x0002;
    /// `port`: port number.
    pub const PORT: u16 = 0x0003;
    /// `ipv4hint`: IPv4 address hints.
    pub const IPV4HINT: u16 = 0x0004;
    /// `ech`: Encrypted Client Hello.
    pub const ECH: u16 = 0x0005;
    /// `ipv6hint`: IPv6 address hints.
    pub const IPV6HINT: u16 = 0x0006;
}
17
/// A single service parameter (SvcParam) from an SVCB/HTTPS record.
///
/// Each named variant corresponds to one of the keys in [`svc_key`]; any other
/// key is kept verbatim as [`SvcParam::Unknown`]. All payload types are
/// `Eq + Hash`, so parameters can be stored in sets and used as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SvcParam {
    /// `mandatory` (key 0): keys that must be understood to use the record.
    Mandatory(Vec<u16>),
    /// `alpn` (key 1): Application-Layer Protocol Negotiation identifiers.
    Alpn(Vec<String>),
    /// `no-default-alpn` (key 2): flag parameter with an empty value.
    NoDefaultAlpn,
    /// `port` (key 3): port number.
    Port(u16),
    /// `ipv4hint` (key 4): IPv4 address hints.
    Ipv4Hint(Vec<Ipv4Addr>),
    /// `ech` (key 5): Encrypted Client Hello configuration as opaque bytes.
    Ech(Vec<u8>),
    /// `ipv6hint` (key 6): IPv6 address hints.
    Ipv6Hint(Vec<Ipv6Addr>),
    /// Any other parameter, kept uninterpreted.
    Unknown {
        /// The numeric SvcParamKey.
        key: u16,
        /// The raw value bytes.
        value: Vec<u8>,
    },
}
38
39impl SvcParam {
40    /// Parse a single `SvcParam` from wire format.
41    pub fn parse(key: u16, data: &[u8]) -> Result<Self, FieldError> {
42        match key {
43            svc_key::MANDATORY => {
44                if !data.len().is_multiple_of(2) {
45                    return Err(FieldError::InvalidValue(
46                        "mandatory param length must be even".to_string(),
47                    ));
48                }
49                let keys: Vec<u16> = data
50                    .chunks_exact(2)
51                    .map(|c| u16::from_be_bytes([c[0], c[1]]))
52                    .collect();
53                Ok(SvcParam::Mandatory(keys))
54            },
55
56            svc_key::ALPN => {
57                let mut alpns = Vec::new();
58                let mut pos = 0;
59                while pos < data.len() {
60                    let len = data[pos] as usize;
61                    pos += 1;
62                    if pos + len > data.len() {
63                        return Err(FieldError::BufferTooShort {
64                            offset: pos,
65                            need: len,
66                            have: data.len() - pos,
67                        });
68                    }
69                    alpns.push(String::from_utf8_lossy(&data[pos..pos + len]).into_owned());
70                    pos += len;
71                }
72                Ok(SvcParam::Alpn(alpns))
73            },
74
75            svc_key::NO_DEFAULT_ALPN => Ok(SvcParam::NoDefaultAlpn),
76
77            svc_key::PORT => {
78                if data.len() != 2 {
79                    return Err(FieldError::InvalidValue(
80                        "port param must be 2 bytes".to_string(),
81                    ));
82                }
83                Ok(SvcParam::Port(u16::from_be_bytes([data[0], data[1]])))
84            },
85
86            svc_key::IPV4HINT => {
87                if !data.len().is_multiple_of(4) {
88                    return Err(FieldError::InvalidValue(
89                        "ipv4hint length must be multiple of 4".to_string(),
90                    ));
91                }
92                let addrs: Vec<Ipv4Addr> = data
93                    .chunks_exact(4)
94                    .map(|c| Ipv4Addr::new(c[0], c[1], c[2], c[3]))
95                    .collect();
96                Ok(SvcParam::Ipv4Hint(addrs))
97            },
98
99            svc_key::ECH => Ok(SvcParam::Ech(data.to_vec())),
100
101            svc_key::IPV6HINT => {
102                if !data.len().is_multiple_of(16) {
103                    return Err(FieldError::InvalidValue(
104                        "ipv6hint length must be multiple of 16".to_string(),
105                    ));
106                }
107                let addrs: Vec<Ipv6Addr> = data
108                    .chunks_exact(16)
109                    .map(|c| {
110                        let mut arr = [0u8; 16];
111                        arr.copy_from_slice(c);
112                        Ipv6Addr::from(arr)
113                    })
114                    .collect();
115                Ok(SvcParam::Ipv6Hint(addrs))
116            },
117
118            _ => Ok(SvcParam::Unknown {
119                key,
120                value: data.to_vec(),
121            }),
122        }
123    }
124
125    /// Get the key for this parameter.
126    #[must_use]
127    pub fn key(&self) -> u16 {
128        match self {
129            SvcParam::Mandatory(_) => svc_key::MANDATORY,
130            SvcParam::Alpn(_) => svc_key::ALPN,
131            SvcParam::NoDefaultAlpn => svc_key::NO_DEFAULT_ALPN,
132            SvcParam::Port(_) => svc_key::PORT,
133            SvcParam::Ipv4Hint(_) => svc_key::IPV4HINT,
134            SvcParam::Ech(_) => svc_key::ECH,
135            SvcParam::Ipv6Hint(_) => svc_key::IPV6HINT,
136            SvcParam::Unknown { key, .. } => *key,
137        }
138    }
139
140    /// Serialize the parameter value (without key and length header).
141    #[must_use]
142    pub fn build_value(&self) -> Vec<u8> {
143        match self {
144            SvcParam::Mandatory(keys) => keys.iter().flat_map(|k| k.to_be_bytes()).collect(),
145            SvcParam::Alpn(alpns) => {
146                let mut out = Vec::new();
147                for alpn in alpns {
148                    out.push(alpn.len() as u8);
149                    out.extend_from_slice(alpn.as_bytes());
150                }
151                out
152            },
153            SvcParam::NoDefaultAlpn => Vec::new(),
154            SvcParam::Port(port) => port.to_be_bytes().to_vec(),
155            SvcParam::Ipv4Hint(addrs) => {
156                addrs.iter().flat_map(std::net::Ipv4Addr::octets).collect()
157            },
158            SvcParam::Ech(data) => data.clone(),
159            SvcParam::Ipv6Hint(addrs) => {
160                addrs.iter().flat_map(std::net::Ipv6Addr::octets).collect()
161            },
162            SvcParam::Unknown { value, .. } => value.clone(),
163        }
164    }
165
166    /// Build the complete key-value pair (key + length + value).
167    #[must_use]
168    pub fn build(&self) -> Vec<u8> {
169        let value = self.build_value();
170        let mut out = Vec::with_capacity(4 + value.len());
171        out.extend_from_slice(&self.key().to_be_bytes());
172        out.extend_from_slice(&(value.len() as u16).to_be_bytes());
173        out.extend_from_slice(&value);
174        out
175    }
176
177    /// Parse all `SvcParams` from wire format.
178    pub fn parse_all(data: &[u8]) -> Result<Vec<Self>, FieldError> {
179        let mut params = Vec::new();
180        let mut pos = 0;
181
182        while pos < data.len() {
183            if pos + 4 > data.len() {
184                return Err(FieldError::BufferTooShort {
185                    offset: pos,
186                    need: 4,
187                    have: data.len(),
188                });
189            }
190            let key = u16::from_be_bytes([data[pos], data[pos + 1]]);
191            let val_len = u16::from_be_bytes([data[pos + 2], data[pos + 3]]) as usize;
192            pos += 4;
193
194            if pos + val_len > data.len() {
195                return Err(FieldError::BufferTooShort {
196                    offset: pos,
197                    need: val_len,
198                    have: data.len() - pos,
199                });
200            }
201
202            let param = Self::parse(key, &data[pos..pos + val_len])?;
203            params.push(param);
204            pos += val_len;
205        }
206
207        Ok(params)
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_alpn_roundtrip() {
        let param = SvcParam::Alpn(vec!["h2".to_string(), "h3".to_string()]);
        let built = param.build();
        let parsed = SvcParam::parse_all(&built).unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0], param);
    }

    #[test]
    fn test_port_roundtrip() {
        let param = SvcParam::Port(443);
        let built = param.build();
        let parsed = SvcParam::parse_all(&built).unwrap();
        assert_eq!(parsed[0], param);
    }

    #[test]
    fn test_ipv4hint_roundtrip() {
        let param = SvcParam::Ipv4Hint(vec![Ipv4Addr::new(1, 2, 3, 4), Ipv4Addr::new(5, 6, 7, 8)]);
        let built = param.build();
        let parsed = SvcParam::parse_all(&built).unwrap();
        assert_eq!(parsed[0], param);
    }

    #[test]
    fn test_ipv6hint_roundtrip() {
        let param = SvcParam::Ipv6Hint(vec![Ipv6Addr::LOCALHOST]);
        let built = param.build();
        let parsed = SvcParam::parse_all(&built).unwrap();
        assert_eq!(parsed[0], param);
    }

    #[test]
    fn test_mandatory_roundtrip() {
        let param = SvcParam::Mandatory(vec![1, 3]); // alpn, port
        let built = param.build();
        let parsed = SvcParam::parse_all(&built).unwrap();
        assert_eq!(parsed[0], param);
    }

    #[test]
    fn test_no_default_alpn() {
        let param = SvcParam::NoDefaultAlpn;
        let built = param.build();
        assert_eq!(built.len(), 4); // key(2) + len(2) + value(0)
        let parsed = SvcParam::parse_all(&built).unwrap();
        assert_eq!(parsed[0], param);
    }

    #[test]
    fn test_ech_roundtrip() {
        let param = SvcParam::Ech(vec![0xfe, 0x0d, 0x00, 0x01, 0x02]);
        let parsed = SvcParam::parse_all(&param.build()).unwrap();
        assert_eq!(parsed, vec![param]);
    }

    #[test]
    fn test_unknown_roundtrip() {
        let param = SvcParam::Unknown {
            key: 0xfffe,
            value: vec![1, 2, 3],
        };
        let built = param.build();
        // Header: key 0xfffe, length 3.
        assert_eq!(&built[..4], &[0xff, 0xfe, 0x00, 0x03]);
        let parsed = SvcParam::parse_all(&built).unwrap();
        assert_eq!(parsed, vec![param]);
    }

    #[test]
    fn test_parse_all_empty() {
        assert!(SvcParam::parse_all(&[]).unwrap().is_empty());
    }

    #[test]
    fn test_malformed_values_rejected() {
        assert!(matches!(
            SvcParam::parse(svc_key::MANDATORY, &[0, 1, 3]),
            Err(FieldError::InvalidValue(_))
        ));
        assert!(matches!(
            SvcParam::parse(svc_key::PORT, &[1]),
            Err(FieldError::InvalidValue(_))
        ));
        assert!(matches!(
            SvcParam::parse(svc_key::IPV4HINT, &[1, 2, 3]),
            Err(FieldError::InvalidValue(_))
        ));
        assert!(matches!(
            SvcParam::parse(svc_key::IPV6HINT, &[0; 15]),
            Err(FieldError::InvalidValue(_))
        ));
        // The alpn-id claims 5 bytes but only 2 follow.
        assert!(matches!(
            SvcParam::parse(svc_key::ALPN, &[5, b'h', b'2']),
            Err(FieldError::BufferTooShort { .. })
        ));
    }

    #[test]
    fn test_parse_all_truncated() {
        // The 4-byte key/length header is cut short.
        assert!(matches!(
            SvcParam::parse_all(&[0, 3, 0]),
            Err(FieldError::BufferTooShort { need: 4, .. })
        ));
        // The declared value length (2) exceeds the single remaining byte.
        assert!(matches!(
            SvcParam::parse_all(&[0, 3, 0, 2, 1]),
            Err(FieldError::BufferTooShort {
                offset: 4,
                need: 2,
                have: 1
            })
        ));
    }

    #[test]
    fn test_multiple_params() {
        let params = vec![
            SvcParam::Alpn(vec!["h2".to_string()]),
            SvcParam::Port(443),
            SvcParam::Ipv4Hint(vec![Ipv4Addr::new(1, 1, 1, 1)]),
        ];
        let mut data = Vec::new();
        for p in &params {
            data.extend_from_slice(&p.build());
        }
        let parsed = SvcParam::parse_all(&data).unwrap();
        assert_eq!(parsed, params);
    }
}