stackforge_core/layer/dns/
svcb.rs1use std::net::{Ipv4Addr, Ipv6Addr};
4
5use crate::layer::field::FieldError;
6
/// Numeric SvcParamKey values that have a dedicated `SvcParam` variant.
///
/// The numbering follows the SvcParamKeys registry of RFC 9460
/// (`mandatory` = 0 through `ipv6hint` = 6); any other key decodes as
/// `SvcParam::Unknown`.
pub mod svc_key {
    /// `mandatory`: keys a client must understand to use the record.
    pub const MANDATORY: u16 = 0x0000;
    /// `alpn`: list of supported application protocol ids.
    pub const ALPN: u16 = 0x0001;
    /// `no-default-alpn`: flag with an empty value.
    pub const NO_DEFAULT_ALPN: u16 = 0x0002;
    /// `port`: 16-bit port number.
    pub const PORT: u16 = 0x0003;
    /// `ipv4hint`: list of IPv4 addresses.
    pub const IPV4HINT: u16 = 0x0004;
    /// `ech`: Encrypted ClientHello configuration bytes.
    pub const ECH: u16 = 0x0005;
    /// `ipv6hint`: list of IPv6 addresses.
    pub const IPV6HINT: u16 = 0x0006;
}
17
/// One SvcParam (key/value pair) from an SVCB or HTTPS record.
///
/// Keys with a dedicated variant are held as typed values; every other key
/// is carried as raw bytes in [`SvcParam::Unknown`].
///
/// All payload types are `Eq + Hash`, so parameters can be compared exactly
/// and stored in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SvcParam {
    /// `mandatory` (key 0): list of SvcParamKeys.
    Mandatory(Vec<u16>),
    /// `alpn` (key 1): application protocol ids, e.g. `"h2"` or `"h3"`.
    Alpn(Vec<String>),
    /// `no-default-alpn` (key 2): carries no value.
    NoDefaultAlpn,
    /// `port` (key 3).
    Port(u16),
    /// `ipv4hint` (key 4).
    Ipv4Hint(Vec<Ipv4Addr>),
    /// `ech` (key 5): opaque configuration bytes.
    Ech(Vec<u8>),
    /// `ipv6hint` (key 6).
    Ipv6Hint(Vec<Ipv6Addr>),
    /// Any key without a dedicated variant, with its raw value bytes.
    Unknown { key: u16, value: Vec<u8> },
}
38
39impl SvcParam {
40 pub fn parse(key: u16, data: &[u8]) -> Result<Self, FieldError> {
42 match key {
43 svc_key::MANDATORY => {
44 if !data.len().is_multiple_of(2) {
45 return Err(FieldError::InvalidValue(
46 "mandatory param length must be even".to_string(),
47 ));
48 }
49 let keys: Vec<u16> = data
50 .chunks_exact(2)
51 .map(|c| u16::from_be_bytes([c[0], c[1]]))
52 .collect();
53 Ok(SvcParam::Mandatory(keys))
54 },
55
56 svc_key::ALPN => {
57 let mut alpns = Vec::new();
58 let mut pos = 0;
59 while pos < data.len() {
60 let len = data[pos] as usize;
61 pos += 1;
62 if pos + len > data.len() {
63 return Err(FieldError::BufferTooShort {
64 offset: pos,
65 need: len,
66 have: data.len() - pos,
67 });
68 }
69 alpns.push(String::from_utf8_lossy(&data[pos..pos + len]).into_owned());
70 pos += len;
71 }
72 Ok(SvcParam::Alpn(alpns))
73 },
74
75 svc_key::NO_DEFAULT_ALPN => Ok(SvcParam::NoDefaultAlpn),
76
77 svc_key::PORT => {
78 if data.len() != 2 {
79 return Err(FieldError::InvalidValue(
80 "port param must be 2 bytes".to_string(),
81 ));
82 }
83 Ok(SvcParam::Port(u16::from_be_bytes([data[0], data[1]])))
84 },
85
86 svc_key::IPV4HINT => {
87 if !data.len().is_multiple_of(4) {
88 return Err(FieldError::InvalidValue(
89 "ipv4hint length must be multiple of 4".to_string(),
90 ));
91 }
92 let addrs: Vec<Ipv4Addr> = data
93 .chunks_exact(4)
94 .map(|c| Ipv4Addr::new(c[0], c[1], c[2], c[3]))
95 .collect();
96 Ok(SvcParam::Ipv4Hint(addrs))
97 },
98
99 svc_key::ECH => Ok(SvcParam::Ech(data.to_vec())),
100
101 svc_key::IPV6HINT => {
102 if !data.len().is_multiple_of(16) {
103 return Err(FieldError::InvalidValue(
104 "ipv6hint length must be multiple of 16".to_string(),
105 ));
106 }
107 let addrs: Vec<Ipv6Addr> = data
108 .chunks_exact(16)
109 .map(|c| {
110 let mut arr = [0u8; 16];
111 arr.copy_from_slice(c);
112 Ipv6Addr::from(arr)
113 })
114 .collect();
115 Ok(SvcParam::Ipv6Hint(addrs))
116 },
117
118 _ => Ok(SvcParam::Unknown {
119 key,
120 value: data.to_vec(),
121 }),
122 }
123 }
124
125 #[must_use]
127 pub fn key(&self) -> u16 {
128 match self {
129 SvcParam::Mandatory(_) => svc_key::MANDATORY,
130 SvcParam::Alpn(_) => svc_key::ALPN,
131 SvcParam::NoDefaultAlpn => svc_key::NO_DEFAULT_ALPN,
132 SvcParam::Port(_) => svc_key::PORT,
133 SvcParam::Ipv4Hint(_) => svc_key::IPV4HINT,
134 SvcParam::Ech(_) => svc_key::ECH,
135 SvcParam::Ipv6Hint(_) => svc_key::IPV6HINT,
136 SvcParam::Unknown { key, .. } => *key,
137 }
138 }
139
140 #[must_use]
142 pub fn build_value(&self) -> Vec<u8> {
143 match self {
144 SvcParam::Mandatory(keys) => keys.iter().flat_map(|k| k.to_be_bytes()).collect(),
145 SvcParam::Alpn(alpns) => {
146 let mut out = Vec::new();
147 for alpn in alpns {
148 out.push(alpn.len() as u8);
149 out.extend_from_slice(alpn.as_bytes());
150 }
151 out
152 },
153 SvcParam::NoDefaultAlpn => Vec::new(),
154 SvcParam::Port(port) => port.to_be_bytes().to_vec(),
155 SvcParam::Ipv4Hint(addrs) => {
156 addrs.iter().flat_map(std::net::Ipv4Addr::octets).collect()
157 },
158 SvcParam::Ech(data) => data.clone(),
159 SvcParam::Ipv6Hint(addrs) => {
160 addrs.iter().flat_map(std::net::Ipv6Addr::octets).collect()
161 },
162 SvcParam::Unknown { value, .. } => value.clone(),
163 }
164 }
165
166 #[must_use]
168 pub fn build(&self) -> Vec<u8> {
169 let value = self.build_value();
170 let mut out = Vec::with_capacity(4 + value.len());
171 out.extend_from_slice(&self.key().to_be_bytes());
172 out.extend_from_slice(&(value.len() as u16).to_be_bytes());
173 out.extend_from_slice(&value);
174 out
175 }
176
177 pub fn parse_all(data: &[u8]) -> Result<Vec<Self>, FieldError> {
179 let mut params = Vec::new();
180 let mut pos = 0;
181
182 while pos < data.len() {
183 if pos + 4 > data.len() {
184 return Err(FieldError::BufferTooShort {
185 offset: pos,
186 need: 4,
187 have: data.len(),
188 });
189 }
190 let key = u16::from_be_bytes([data[pos], data[pos + 1]]);
191 let val_len = u16::from_be_bytes([data[pos + 2], data[pos + 3]]) as usize;
192 pos += 4;
193
194 if pos + val_len > data.len() {
195 return Err(FieldError::BufferTooShort {
196 offset: pos,
197 need: val_len,
198 have: data.len() - pos,
199 });
200 }
201
202 let param = Self::parse(key, &data[pos..pos + val_len])?;
203 params.push(param);
204 pos += val_len;
205 }
206
207 Ok(params)
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `param`, decodes the bytes again and checks that exactly the
    /// same single parameter comes back.
    fn assert_roundtrip(param: SvcParam) {
        let built = param.build();
        let parsed = SvcParam::parse_all(&built).unwrap();
        assert_eq!(parsed, vec![param]);
    }

    #[test]
    fn test_alpn_roundtrip() {
        assert_roundtrip(SvcParam::Alpn(vec!["h2".to_string(), "h3".to_string()]));
    }

    #[test]
    fn test_port_roundtrip() {
        assert_roundtrip(SvcParam::Port(443));
    }

    #[test]
    fn test_ipv4hint_roundtrip() {
        assert_roundtrip(SvcParam::Ipv4Hint(vec![
            Ipv4Addr::new(1, 2, 3, 4),
            Ipv4Addr::new(5, 6, 7, 8),
        ]));
    }

    #[test]
    fn test_ipv6hint_roundtrip() {
        assert_roundtrip(SvcParam::Ipv6Hint(vec![Ipv6Addr::LOCALHOST]));
    }

    #[test]
    fn test_mandatory_roundtrip() {
        assert_roundtrip(SvcParam::Mandatory(vec![1, 3]));
    }

    #[test]
    fn test_ech_roundtrip() {
        assert_roundtrip(SvcParam::Ech(vec![0xde, 0xad, 0xbe, 0xef]));
    }

    #[test]
    fn test_unknown_roundtrip() {
        // 65280 is in the SvcParamKey private-use range, so it has no variant.
        assert_roundtrip(SvcParam::Unknown {
            key: 65280,
            value: vec![1, 2, 3],
        });
    }

    #[test]
    fn test_no_default_alpn() {
        let param = SvcParam::NoDefaultAlpn;
        let built = param.build();
        // key 2, length 0, no value bytes
        assert_eq!(built, vec![0, 2, 0, 0]);
        let parsed = SvcParam::parse_all(&built).unwrap();
        assert_eq!(parsed[0], param);
    }

    #[test]
    fn test_multiple_params() {
        let params = vec![
            SvcParam::Alpn(vec!["h2".to_string()]),
            SvcParam::Port(443),
            SvcParam::Ipv4Hint(vec![Ipv4Addr::new(1, 1, 1, 1)]),
        ];
        let data: Vec<u8> = params.iter().flat_map(SvcParam::build).collect();
        let parsed = SvcParam::parse_all(&data).unwrap();
        assert_eq!(parsed, params);
    }

    #[test]
    fn test_empty_input_parses_to_nothing() {
        assert!(SvcParam::parse_all(&[]).unwrap().is_empty());
    }

    #[test]
    fn test_wrong_fixed_lengths_rejected() {
        assert!(matches!(
            SvcParam::parse(svc_key::PORT, &[1]),
            Err(FieldError::InvalidValue(_))
        ));
        assert!(matches!(
            SvcParam::parse(svc_key::MANDATORY, &[0, 1, 2]),
            Err(FieldError::InvalidValue(_))
        ));
        assert!(matches!(
            SvcParam::parse(svc_key::IPV4HINT, &[1, 2, 3]),
            Err(FieldError::InvalidValue(_))
        ));
        assert!(matches!(
            SvcParam::parse(svc_key::IPV6HINT, &[0; 15]),
            Err(FieldError::InvalidValue(_))
        ));
    }

    #[test]
    fn test_truncated_input_rejected() {
        // Header cut short before the 4-byte key/length pair is complete.
        assert!(matches!(
            SvcParam::parse_all(&[0, 1, 0]),
            Err(FieldError::BufferTooShort { .. })
        ));
        // Declared value length (2) exceeds the single byte present.
        assert!(matches!(
            SvcParam::parse_all(&[0, 3, 0, 2, 1]),
            Err(FieldError::BufferTooShort {
                offset: 4,
                need: 2,
                have: 1
            })
        ));
        // ALPN id length prefix (5) runs past the end of the value.
        assert!(matches!(
            SvcParam::parse(svc_key::ALPN, &[5, b'h', b'2']),
            Err(FieldError::BufferTooShort {
                offset: 1,
                need: 5,
                have: 2
            })
        ));
    }
}
279}