//! Encoding and decoding of SVCB service parameters (`SvcParam`).
//!
//! Source path: `stackforge_core/layer/dns/svcb.rs`.

use std::net::{Ipv4Addr, Ipv6Addr};

use crate::layer::field::FieldError;

/// Numeric SvcParamKey values that [`SvcParam`] decodes into dedicated variants.
///
/// Any key not listed here is decoded as `SvcParam::Unknown`.
pub mod svc_key {
    /// `mandatory`: a list of big-endian `u16` keys.
    pub const MANDATORY: u16 = 0x0000;
    /// `alpn`: one-byte-length-prefixed protocol identifiers.
    pub const ALPN: u16 = 0x0001;
    /// `no-default-alpn`: carries an empty value.
    pub const NO_DEFAULT_ALPN: u16 = 0x0002;
    /// `port`: a single big-endian `u16`.
    pub const PORT: u16 = 0x0003;
    /// `ipv4hint`: concatenated 4-byte IPv4 addresses.
    pub const IPV4HINT: u16 = 0x0004;
    /// `ech`: opaque bytes, passed through uninterpreted.
    pub const ECH: u16 = 0x0005;
    /// `ipv6hint`: concatenated 16-byte IPv6 addresses.
    pub const IPV6HINT: u16 = 0x0006;
}

/// A single decoded SVCB service parameter.
///
/// Each named variant corresponds to one key in the `svc_key` module; any
/// other key is kept verbatim in [`SvcParam::Unknown`].
///
/// All payload types are `Eq + Hash`, so the enum derives them too and can be
/// stored in `HashSet`/`HashMap` or compared with full equality.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SvcParam {
    /// `mandatory` (key 0): list of keys, each a big-endian `u16` on the wire.
    Mandatory(Vec<u16>),
    /// `alpn` (key 1): protocol identifiers, each prefixed by a one-byte length on the wire.
    Alpn(Vec<String>),
    /// `no-default-alpn` (key 2): has no value.
    NoDefaultAlpn,
    /// `port` (key 3): a big-endian `u16` on the wire.
    Port(u16),
    /// `ipv4hint` (key 4): concatenated 4-byte addresses on the wire.
    Ipv4Hint(Vec<Ipv4Addr>),
    /// `ech` (key 5): raw bytes, not interpreted by this module.
    Ech(Vec<u8>),
    /// `ipv6hint` (key 6): concatenated 16-byte addresses on the wire.
    Ipv6Hint(Vec<Ipv6Addr>),
    /// Any key without a dedicated variant, with its value bytes preserved as-is.
    Unknown { key: u16, value: Vec<u8> },
}

39impl SvcParam {
40 pub fn parse(key: u16, data: &[u8]) -> Result<Self, FieldError> {
42 match key {
43 svc_key::MANDATORY => {
44 if data.len() % 2 != 0 {
45 return Err(FieldError::InvalidValue(
46 "mandatory param length must be even".to_string(),
47 ));
48 }
49 let keys: Vec<u16> = data
50 .chunks_exact(2)
51 .map(|c| u16::from_be_bytes([c[0], c[1]]))
52 .collect();
53 Ok(SvcParam::Mandatory(keys))
54 },
55
56 svc_key::ALPN => {
57 let mut alpns = Vec::new();
58 let mut pos = 0;
59 while pos < data.len() {
60 let len = data[pos] as usize;
61 pos += 1;
62 if pos + len > data.len() {
63 return Err(FieldError::BufferTooShort {
64 offset: pos,
65 need: len,
66 have: data.len() - pos,
67 });
68 }
69 alpns.push(String::from_utf8_lossy(&data[pos..pos + len]).into_owned());
70 pos += len;
71 }
72 Ok(SvcParam::Alpn(alpns))
73 },
74
75 svc_key::NO_DEFAULT_ALPN => Ok(SvcParam::NoDefaultAlpn),
76
77 svc_key::PORT => {
78 if data.len() != 2 {
79 return Err(FieldError::InvalidValue(
80 "port param must be 2 bytes".to_string(),
81 ));
82 }
83 Ok(SvcParam::Port(u16::from_be_bytes([data[0], data[1]])))
84 },
85
86 svc_key::IPV4HINT => {
87 if data.len() % 4 != 0 {
88 return Err(FieldError::InvalidValue(
89 "ipv4hint length must be multiple of 4".to_string(),
90 ));
91 }
92 let addrs: Vec<Ipv4Addr> = data
93 .chunks_exact(4)
94 .map(|c| Ipv4Addr::new(c[0], c[1], c[2], c[3]))
95 .collect();
96 Ok(SvcParam::Ipv4Hint(addrs))
97 },
98
99 svc_key::ECH => Ok(SvcParam::Ech(data.to_vec())),
100
101 svc_key::IPV6HINT => {
102 if data.len() % 16 != 0 {
103 return Err(FieldError::InvalidValue(
104 "ipv6hint length must be multiple of 16".to_string(),
105 ));
106 }
107 let addrs: Vec<Ipv6Addr> = data
108 .chunks_exact(16)
109 .map(|c| {
110 let mut arr = [0u8; 16];
111 arr.copy_from_slice(c);
112 Ipv6Addr::from(arr)
113 })
114 .collect();
115 Ok(SvcParam::Ipv6Hint(addrs))
116 },
117
118 _ => Ok(SvcParam::Unknown {
119 key,
120 value: data.to_vec(),
121 }),
122 }
123 }
124
125 pub fn key(&self) -> u16 {
127 match self {
128 SvcParam::Mandatory(_) => svc_key::MANDATORY,
129 SvcParam::Alpn(_) => svc_key::ALPN,
130 SvcParam::NoDefaultAlpn => svc_key::NO_DEFAULT_ALPN,
131 SvcParam::Port(_) => svc_key::PORT,
132 SvcParam::Ipv4Hint(_) => svc_key::IPV4HINT,
133 SvcParam::Ech(_) => svc_key::ECH,
134 SvcParam::Ipv6Hint(_) => svc_key::IPV6HINT,
135 SvcParam::Unknown { key, .. } => *key,
136 }
137 }
138
139 pub fn build_value(&self) -> Vec<u8> {
141 match self {
142 SvcParam::Mandatory(keys) => keys.iter().flat_map(|k| k.to_be_bytes()).collect(),
143 SvcParam::Alpn(alpns) => {
144 let mut out = Vec::new();
145 for alpn in alpns {
146 out.push(alpn.len() as u8);
147 out.extend_from_slice(alpn.as_bytes());
148 }
149 out
150 },
151 SvcParam::NoDefaultAlpn => Vec::new(),
152 SvcParam::Port(port) => port.to_be_bytes().to_vec(),
153 SvcParam::Ipv4Hint(addrs) => addrs.iter().flat_map(|a| a.octets()).collect(),
154 SvcParam::Ech(data) => data.clone(),
155 SvcParam::Ipv6Hint(addrs) => addrs.iter().flat_map(|a| a.octets()).collect(),
156 SvcParam::Unknown { value, .. } => value.clone(),
157 }
158 }
159
160 pub fn build(&self) -> Vec<u8> {
162 let value = self.build_value();
163 let mut out = Vec::with_capacity(4 + value.len());
164 out.extend_from_slice(&self.key().to_be_bytes());
165 out.extend_from_slice(&(value.len() as u16).to_be_bytes());
166 out.extend_from_slice(&value);
167 out
168 }
169
170 pub fn parse_all(data: &[u8]) -> Result<Vec<Self>, FieldError> {
172 let mut params = Vec::new();
173 let mut pos = 0;
174
175 while pos < data.len() {
176 if pos + 4 > data.len() {
177 return Err(FieldError::BufferTooShort {
178 offset: pos,
179 need: 4,
180 have: data.len(),
181 });
182 }
183 let key = u16::from_be_bytes([data[pos], data[pos + 1]]);
184 let val_len = u16::from_be_bytes([data[pos + 2], data[pos + 3]]) as usize;
185 pos += 4;
186
187 if pos + val_len > data.len() {
188 return Err(FieldError::BufferTooShort {
189 offset: pos,
190 need: val_len,
191 have: data.len() - pos,
192 });
193 }
194
195 let param = Self::parse(key, &data[pos..pos + val_len])?;
196 params.push(param);
197 pos += val_len;
198 }
199
200 Ok(params)
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_alpn_roundtrip() {
        let param = SvcParam::Alpn(vec!["h2".to_string(), "h3".to_string()]);
        let built = param.build();
        let parsed = SvcParam::parse_all(&built).unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0], param);
    }

    #[test]
    fn test_port_roundtrip() {
        let param = SvcParam::Port(443);
        let built = param.build();
        let parsed = SvcParam::parse_all(&built).unwrap();
        assert_eq!(parsed[0], param);
    }

    #[test]
    fn test_ipv4hint_roundtrip() {
        let param = SvcParam::Ipv4Hint(vec![Ipv4Addr::new(1, 2, 3, 4), Ipv4Addr::new(5, 6, 7, 8)]);
        let built = param.build();
        let parsed = SvcParam::parse_all(&built).unwrap();
        assert_eq!(parsed[0], param);
    }

    #[test]
    fn test_ipv6hint_roundtrip() {
        let param = SvcParam::Ipv6Hint(vec![Ipv6Addr::LOCALHOST]);
        let built = param.build();
        let parsed = SvcParam::parse_all(&built).unwrap();
        assert_eq!(parsed[0], param);
    }

    #[test]
    fn test_mandatory_roundtrip() {
        let param = SvcParam::Mandatory(vec![1, 3]);
        let built = param.build();
        let parsed = SvcParam::parse_all(&built).unwrap();
        assert_eq!(parsed[0], param);
    }

    #[test]
    fn test_no_default_alpn() {
        let param = SvcParam::NoDefaultAlpn;
        let built = param.build();
        // Only the 4-byte key/length header; the value is empty.
        assert_eq!(built.len(), 4);
        let parsed = SvcParam::parse_all(&built).unwrap();
        assert_eq!(parsed[0], param);
    }

    #[test]
    fn test_ech_and_unknown_roundtrip() {
        let params = vec![
            SvcParam::Ech(vec![0xde, 0xad, 0xbe, 0xef]),
            SvcParam::Unknown {
                key: 0xff00,
                value: vec![1, 2, 3],
            },
        ];
        let data: Vec<u8> = params.iter().flat_map(SvcParam::build).collect();
        assert_eq!(SvcParam::parse_all(&data).unwrap(), params);
    }

    #[test]
    fn test_multiple_params() {
        let params = vec![
            SvcParam::Alpn(vec!["h2".to_string()]),
            SvcParam::Port(443),
            SvcParam::Ipv4Hint(vec![Ipv4Addr::new(1, 1, 1, 1)]),
        ];
        let mut data = Vec::new();
        for p in &params {
            data.extend_from_slice(&p.build());
        }
        let parsed = SvcParam::parse_all(&data).unwrap();
        assert_eq!(parsed, params);
    }

    #[test]
    fn test_empty_input() {
        assert!(SvcParam::parse_all(&[]).unwrap().is_empty());
    }

    #[test]
    fn test_truncated_value_is_rejected() {
        // Port header claims 2 value bytes but only 1 follows.
        let err = SvcParam::parse_all(&[0, 3, 0, 2, 1]).unwrap_err();
        assert!(matches!(
            err,
            FieldError::BufferTooShort {
                offset: 4,
                need: 2,
                have: 1
            }
        ));
    }

    #[test]
    fn test_bad_port_length_is_rejected() {
        let err = SvcParam::parse(svc_key::PORT, &[1]).unwrap_err();
        assert!(matches!(err, FieldError::InvalidValue(_)));
    }
}