1use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
7
8use super::cursor::{Cursor, CursorMut};
9use super::Error;
10
/// Challenge response packet.
///
/// On the wire this is [`ChallengeResponse::HEADER`] followed by one or two
/// little-endian `u32` values.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ChallengeResponse {
    /// First `u32` after the header; always present.
    pub master_challenge: u32,
    /// Second `u32` after the header; `None` when the packet carries only
    /// the first value.
    pub server_challenge: Option<u32>,
}
19
20impl ChallengeResponse {
21 pub const HEADER: &'static [u8] = b"\xff\xff\xff\xffs\n";
23
24 pub fn new(master_challenge: u32, server_challenge: Option<u32>) -> Self {
26 Self {
27 master_challenge,
28 server_challenge,
29 }
30 }
31
32 pub fn decode(src: &[u8]) -> Result<Self, Error> {
34 let mut cur = Cursor::new(src);
35 cur.expect(Self::HEADER)?;
36 let master_challenge = cur.get_u32_le()?;
37 let server_challenge = if cur.remaining() == 4 {
38 Some(cur.get_u32_le()?)
39 } else {
40 None
41 };
42 cur.expect_empty()?;
43 Ok(Self {
44 master_challenge,
45 server_challenge,
46 })
47 }
48
49 pub fn encode<const N: usize>(&self, buf: &mut [u8; N]) -> Result<usize, Error> {
51 let mut cur = CursorMut::new(buf);
52 cur.put_bytes(Self::HEADER)?;
53 cur.put_u32_le(self.master_challenge)?;
54 if let Some(server_challenge) = self.server_challenge {
55 cur.put_u32_le(server_challenge)?;
56 }
57 Ok(cur.pos())
58 }
59}
60
61pub trait ServerAddress: Sized {
63 fn size() -> usize;
65
66 fn get(cur: &mut Cursor) -> Result<Self, Error>;
68
69 fn put(&self, cur: &mut CursorMut) -> Result<(), Error>;
71}
72
73impl ServerAddress for SocketAddrV4 {
74 fn size() -> usize {
75 6
76 }
77
78 fn get(cur: &mut Cursor) -> Result<Self, Error> {
79 let ip = Ipv4Addr::from(cur.get_array()?);
80 let port = cur.get_u16_be()?;
81 Ok(SocketAddrV4::new(ip, port))
82 }
83
84 fn put(&self, cur: &mut CursorMut) -> Result<(), Error> {
85 cur.put_array(&self.ip().octets())?;
86 cur.put_u16_be(self.port())?;
87 Ok(())
88 }
89}
90
91impl ServerAddress for SocketAddrV6 {
92 fn size() -> usize {
93 18
94 }
95
96 fn get(cur: &mut Cursor) -> Result<Self, Error> {
97 let ip = Ipv6Addr::from(cur.get_array()?);
98 let port = cur.get_u16_be()?;
99 Ok(SocketAddrV6::new(ip, port, 0, 0))
100 }
101
102 fn put(&self, cur: &mut CursorMut) -> Result<(), Error> {
103 cur.put_array(&self.ip().octets())?;
104 cur.put_u16_be(self.port())?;
105 Ok(())
106 }
107}
108
/// Server list response packet.
///
/// The type parameter selects the role of the value: `&[u8]` holds the raw
/// address bytes of a decoded packet, while `()` is used when building a
/// packet to encode.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct QueryServersResponse<I> {
    /// Payload storage; see the type-level documentation.
    inner: I,
    /// Optional key. When present it is encoded right after the header as
    /// the byte `0x7f`, the key as a little-endian `u32`, and the byte `8`.
    pub key: Option<u32>,
}

impl QueryServersResponse<()> {
    /// Byte sequence every server list response starts with.
    pub const HEADER: &'static [u8] = b"\xff\xff\xff\xfff\n";
}
121
122impl<'a> QueryServersResponse<&'a [u8]> {
123 pub fn decode(src: &'a [u8]) -> Result<Self, Error> {
125 let mut cur = Cursor::new(src);
126 cur.expect(QueryServersResponse::HEADER)?;
127 let s = cur.end();
128
129 let (inner, key) = if s.len() >= 6 && s[0] == 0x7f && s[5] == 8 {
131 let key = u32::from_le_bytes([s[1], s[2], s[3], s[4]]);
132 (&s[6..], Some(key))
133 } else {
134 (s, None)
135 };
136
137 Ok(Self { inner, key })
138 }
139
140 pub fn iter<A>(&self) -> impl 'a + Iterator<Item = A>
142 where
143 A: ServerAddress,
144 {
145 let mut cur = Cursor::new(self.inner);
146 std::iter::from_fn(move || {
147 if cur.remaining() == A::size() && cur.end().ends_with(&[0; 2]) {
148 return None;
150 }
151 A::get(&mut cur).ok()
152 })
153 }
154
155 pub fn is_empty(&self) -> bool {
157 self.inner.is_empty()
158 }
159}
160
161impl QueryServersResponse<()> {
162 pub fn new(key: Option<u32>) -> Self {
164 Self { inner: (), key }
165 }
166
167 pub fn encode<A>(&mut self, buf: &mut [u8], list: &[A]) -> Result<(usize, usize), Error>
174 where
175 A: ServerAddress,
176 {
177 let mut cur = CursorMut::new(buf);
178 cur.put_bytes(QueryServersResponse::HEADER)?;
179 if let Some(key) = self.key {
180 cur.put_u8(0x7f)?;
181 cur.put_u32_le(key)?;
182 cur.put_u8(8)?;
183 }
184 let mut count = 0;
185 let mut iter = list.iter();
186 while cur.remaining() >= A::size() * 2 {
187 if let Some(i) = iter.next() {
188 i.put(&mut cur)?;
189 count += 1;
190 } else {
191 break;
192 }
193 }
194 for _ in 0..A::size() {
195 cur.put_u8(0)?;
196 }
197 Ok((cur.pos(), count))
198 }
199}
200
/// Client announce packet carrying a single socket address.
///
/// On the wire the address follows the header in its textual form.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ClientAnnounce {
    /// Socket address carried by the packet.
    pub addr: SocketAddr,
}
207
208impl ClientAnnounce {
209 pub const HEADER: &'static [u8] = b"\xff\xff\xff\xffc ";
211
212 pub fn new(addr: SocketAddr) -> Self {
214 Self { addr }
215 }
216
217 pub fn decode(src: &[u8]) -> Result<Self, Error> {
219 let mut cur = Cursor::new(src);
220 cur.expect(Self::HEADER)?;
221 let addr = cur
222 .get_str(cur.remaining())?
223 .parse()
224 .map_err(|_| Error::InvalidClientAnnounceIp)?;
225 cur.expect_empty()?;
226 Ok(Self { addr })
227 }
228
229 pub fn encode(&self, buf: &mut [u8]) -> Result<usize, Error> {
231 Ok(CursorMut::new(buf)
232 .put_bytes(Self::HEADER)?
233 .put_as_str(self.addr)?
234 .pos())
235 }
236}
237
/// Admin challenge response packet.
///
/// On the wire this is [`AdminChallengeResponse::HEADER`] followed by two
/// little-endian `u32` values.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AdminChallengeResponse {
    /// First `u32` after the header.
    pub master_challenge: u32,
    /// Second `u32` after the header.
    pub hash_challenge: u32,
}
246
247impl AdminChallengeResponse {
248 pub const HEADER: &'static [u8] = b"\xff\xff\xff\xffadminchallenge";
250
251 pub fn new(master_challenge: u32, hash_challenge: u32) -> Self {
253 Self {
254 master_challenge,
255 hash_challenge,
256 }
257 }
258
259 pub fn decode(src: &[u8]) -> Result<Self, Error> {
261 let mut cur = Cursor::new(src);
262 cur.expect(Self::HEADER)?;
263 let master_challenge = cur.get_u32_le()?;
264 let hash_challenge = cur.get_u32_le()?;
265 cur.expect_empty()?;
266 Ok(Self {
267 master_challenge,
268 hash_challenge,
269 })
270 }
271
272 pub fn encode(&self, buf: &mut [u8]) -> Result<usize, Error> {
274 Ok(CursorMut::new(buf)
275 .put_bytes(Self::HEADER)?
276 .put_u32_le(self.master_challenge)?
277 .put_u32_le(self.hash_challenge)?
278 .pos())
279 }
280}
281
282#[derive(Clone, Debug, PartialEq)]
284pub enum Packet<'a> {
285 ChallengeResponse(ChallengeResponse),
287 QueryServersResponse(QueryServersResponse<&'a [u8]>),
289 ClientAnnounce(ClientAnnounce),
291 AdminChallengeResponse(AdminChallengeResponse),
293}
294
295impl<'a> Packet<'a> {
296 pub fn decode(src: &'a [u8]) -> Result<Option<Self>, Error> {
298 if src.starts_with(ChallengeResponse::HEADER) {
299 ChallengeResponse::decode(src).map(Self::ChallengeResponse)
300 } else if src.starts_with(QueryServersResponse::HEADER) {
301 QueryServersResponse::decode(src).map(Self::QueryServersResponse)
302 } else if src.starts_with(ClientAnnounce::HEADER) {
303 ClientAnnounce::decode(src).map(Self::ClientAnnounce)
304 } else if src.starts_with(AdminChallengeResponse::HEADER) {
305 AdminChallengeResponse::decode(src).map(Self::AdminChallengeResponse)
306 } else {
307 return Ok(None);
308 }
309 .map(Some)
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `servers` with a key, decodes the bytes again and checks that
    /// the same list comes back.
    fn check_server_list<A>(servers: &[A])
    where
        A: ServerAddress + PartialEq + std::fmt::Debug,
    {
        let mut response = QueryServersResponse::new(Some(0xdeadbeef));
        let mut buf = [0; 512];
        let (len, count) = response.encode(&mut buf, servers).unwrap();
        assert_eq!(count, servers.len());
        // Header and key block take 12 bytes; one extra record is the terminator.
        assert_eq!(len, 12 + A::size() * (servers.len() + 1));
        let decoded = QueryServersResponse::decode(&buf[..len]).unwrap();
        assert_eq!(decoded.iter::<A>().collect::<Vec<_>>(), servers);
    }

    #[test]
    fn challenge_response() {
        let packet = ChallengeResponse::new(0x12345678, Some(0x87654321));
        let mut buf = [0; 512];
        let len = packet.encode(&mut buf).unwrap();
        let decoded = Packet::decode(&buf[..len]);
        assert_eq!(decoded, Ok(Some(Packet::ChallengeResponse(packet))));
    }

    #[test]
    fn challenge_response_old() {
        // Old format: only the first challenge value is present.
        let raw = b"\xff\xff\xff\xffs\n\x78\x56\x34\x12";
        let expected = ChallengeResponse::new(0x12345678, None);
        assert_eq!(ChallengeResponse::decode(raw), Ok(expected));

        let packet = ChallengeResponse::new(0x12345678, None);
        let mut buf = [0; 512];
        let len = packet.encode(&mut buf).unwrap();
        let decoded = Packet::decode(&buf[..len]);
        assert_eq!(decoded, Ok(Some(Packet::ChallengeResponse(packet))));
    }

    #[test]
    fn query_servers_response_ipv4() {
        let servers: [SocketAddrV4; 4] = [
            "1.2.3.4:27001".parse().unwrap(),
            "1.2.3.4:27002".parse().unwrap(),
            "1.2.3.4:27003".parse().unwrap(),
            "1.2.3.4:27004".parse().unwrap(),
        ];
        check_server_list(&servers);
    }

    #[test]
    fn query_servers_response_ipv6() {
        let servers: [SocketAddrV6; 4] = [
            "[::1]:27001".parse().unwrap(),
            "[::2]:27002".parse().unwrap(),
            "[::3]:27003".parse().unwrap(),
            "[::4]:27004".parse().unwrap(),
        ];
        check_server_list(&servers);
    }

    #[test]
    fn client_announce() {
        let packet = ClientAnnounce::new("1.2.3.4:12345".parse().unwrap());
        let mut buf = [0; 512];
        let len = packet.encode(&mut buf).unwrap();
        let decoded = Packet::decode(&buf[..len]);
        assert_eq!(decoded, Ok(Some(Packet::ClientAnnounce(packet))));
    }

    #[test]
    fn admin_challenge_response() {
        let packet = AdminChallengeResponse::new(0x12345678, 0x87654321);
        let mut buf = [0; 64];
        let len = packet.encode(&mut buf).unwrap();
        let decoded = Packet::decode(&buf[..len]);
        assert_eq!(decoded, Ok(Some(Packet::AdminChallengeResponse(packet))));
    }
}