1use std::net::{SocketAddr, SocketAddrV4, UdpSocket};
4use std::collections::{HashMap, hash_map};
5use std::time::{Duration, Instant};
6use std::sync::{Arc, Mutex};
7use std::io::{self, Cursor};
8
9use blowfish::Blowfish;
10
11use super::filter::{BlowfishReader, BlowfishWriter, blowfish::BLOCK_SIZE};
12use super::packet::{Packet, RawPacket, PacketConfig, PacketConfigError};
13use super::bundle::Bundle;
14
15
/// How long a partially received bundle is kept before its packets are
/// rejected with `PacketRejectionError::TimedOut`.
const FRAGMENT_TIMEOUT: Duration = Duration::from_secs(10);
18
19
/// A UDP socket that sends and receives whole bundles of packets, with
/// optional per-peer Blowfish encryption (see [`BundleSocket::set_channel`]).
/// Cloning is cheap: all clones share the same underlying socket and state.
#[derive(Clone)]
pub struct BundleSocket {
    /// State shared between all clones of this socket.
    shared: Arc<Shared>,
}
34
/// State shared between all clones of a [`BundleSocket`].
struct Shared {
    /// The address this socket was bound to.
    addr: SocketAddrV4,
    /// The underlying UDP socket.
    socket: UdpSocket,
    /// Mutable state, guarded by a mutex so clones can be used concurrently.
    mutable: Mutex<SharedMutable>,
}
44
/// Mutable state of a [`BundleSocket`], protected by the mutex in `Shared`.
struct SharedMutable {
    /// Partially received bundles, keyed by (sender address, first sequence
    /// number of the bundle).
    fragments: HashMap<(SocketAddr, u32), BundleFragments>,
    /// Sequence number that will be assigned to the next sent packet.
    next_sequence_num: u32,
    /// Per-peer encrypted channel state.
    channels: HashMap<SocketAddr, Channel>,
    /// Scratch packet reused as the destination of `encrypt_packet`, to
    /// avoid re-allocating on every encrypted send.
    encryption_packet: Box<RawPacket>,
    /// Packets that could not be processed, with the rejection reason;
    /// drained by `take_rejected_packets`.
    rejected_packets: Vec<(SocketAddr, Box<Packet>, PacketRejectionError)>,
}
59
60impl BundleSocket {
61
62 pub fn new(addr: SocketAddrV4) -> io::Result<Self> {
64
65 let socket = UdpSocket::bind(SocketAddr::V4(addr))?;
66
67 Ok(Self {
68 shared: Arc::new(Shared {
69 addr,
70 socket,
71 mutable: Mutex::new(SharedMutable {
72 fragments: HashMap::new(),
73 next_sequence_num: 0,
74 channels: HashMap::new(),
75 encryption_packet: Box::new(RawPacket::new()),
76 rejected_packets: Vec::new(),
77 }),
78 }),
79 })
80
81 }
82
83 #[inline]
85 pub fn addr(&self) -> SocketAddrV4 {
86 self.shared.addr
87 }
88
89 pub fn set_channel(&mut self, addr: SocketAddr, blowfish: Arc<Blowfish>) {
93 self.shared.mutable.lock().unwrap()
94 .channels.insert(addr, Channel::new(blowfish));
95 }
96
97 pub fn send(&mut self, bundle: &mut Bundle, to: SocketAddr) -> io::Result<usize> {
104
105 if bundle.is_empty() {
107 return Ok(0)
108 }
109
110 let mut mutable = self.shared.mutable.lock().unwrap();
112 let SharedMutable {
113 next_sequence_num,
114 channels,
115 encryption_packet,
116 ..
117 } = &mut *mutable;
118
119 let mut channel = channels.get_mut(&to);
121
122 let sequence_first_num = *next_sequence_num;
124 *next_sequence_num = next_sequence_num.checked_add(bundle.len() as u32).expect("sequence num overflow");
125 let sequence_last_num = *next_sequence_num - 1;
126
127 let mut packet_config = PacketConfig::new();
129
130 if let Some(channel) = channel.as_deref_mut() {
132
133 packet_config.set_on_channel(true);
134 packet_config.set_reliable(true);
135
136 if packet_config.cumulative_ack().is_some() {
139 channel.take_auto_ack();
140 }
141
142 }
143
144 if sequence_last_num > sequence_first_num {
146 packet_config.set_sequence_range(sequence_first_num, sequence_last_num);
147 }
148
149 let mut size = 0;
150 let mut sequence_num = sequence_first_num;
151
152 for packet in bundle.packets_mut() {
153
154 if sequence_num == sequence_last_num {
156 if let Some(channel) = channel.as_mut() {
157 if let Some(num) = channel.get_cumulative_ack_exclusive() {
158 packet_config.set_cumulative_ack(num);
159 } else {
160 packet_config.clear_cumulative_ack();
161 }
162 }
163 }
164
165 packet_config.set_sequence_num(sequence_num);
167 packet.write_config(&mut packet_config);
168
169 let raw_packet;
171
172 if let Some(channel) = channel.as_deref_mut() {
173
174 channel.add_sent_ack(sequence_num);
175
176 encrypt_packet(packet.raw(), &channel.blowfish, &mut **encryption_packet);
177 raw_packet = &**encryption_packet;
178
179 } else {
180 raw_packet = packet.raw()
181 }
182
183 size += self.shared.socket.send_to(raw_packet.data(), to)?;
186 sequence_num += 1;
187
188 }
189
190 Ok(size)
191
192 }
193
194 pub fn recv(&mut self) -> io::Result<Option<Bundle>> {
200
201 let mut packet = Packet::new_boxed();
202 let (len, addr) = self.shared.socket.recv_from(packet.raw_mut().raw_data_mut())?;
203
204 packet.raw_mut().set_data_len(len);
206
207 let mut mutable = self.shared.mutable.lock().unwrap();
210 let SharedMutable {
211 channels,
212 fragments,
213 rejected_packets,
214 ..
215 } = &mut *mutable;
216
217 let mut channel = channels.get_mut(&addr);
219
220 if let Some(channel) = channel.as_deref_mut() {
223 match decrypt_packet(&packet, &channel.blowfish) {
224 Ok(clear_packet) => packet = clear_packet,
225 Err(()) => {
226 mutable.rejected_packets.push((addr, packet, PacketRejectionError::InvalidEncryption));
227 return Ok(None);
228 }
229 }
230 }
231
232 let len = packet.raw().data_len();
234
235 let mut packet_config = PacketConfig::new();
237 if let Err(error) = packet.read_config(len, &mut packet_config) {
238 mutable.rejected_packets.push((addr, packet, PacketRejectionError::Config(error)));
239 return Ok(None);
240 }
241
242 if let Some(channel) = channel.as_deref_mut() {
244
245 if packet_config.reliable() {
247 channel.add_received_ack(packet_config.sequence_num());
248 channel.set_auto_ack();
249 }
250
251 if let Some(ack) = packet_config.cumulative_ack() {
252 channel.remove_cumulative_ack(ack);
253 }
254
255 }
256
257 if packet_config.unk_1000().is_none() {
261
262 let instant = Instant::now();
263
264 match packet_config.sequence_range() {
265 Some((first_num, last_num)) if last_num > first_num => {
267
268 let num = packet_config.sequence_num();
269
270 match fragments.entry((addr, first_num)) {
271 hash_map::Entry::Occupied(mut o) => {
272
273 if o.get().is_old(instant, FRAGMENT_TIMEOUT) {
277 rejected_packets.extend(o.get_mut().drain()
278 .map(|packet| (addr, packet, PacketRejectionError::TimedOut)));
279 }
280
281 o.get_mut().set(num, packet);
282
283 if o.get().is_full() {
285 return Ok(Some(o.remove().into_bundle()));
286 }
287
288 },
289 hash_map::Entry::Vacant(v) => {
290 let mut fragments = BundleFragments::new(last_num - first_num + 1);
291 fragments.set(num, packet);
292 v.insert(fragments);
293 }
294 }
295
296 }
297 _ => {
299 return Ok(Some(Bundle::with_single(packet)));
300 }
301 }
302
303 }
304
305 Ok(None)
307
308 }
309
310 pub fn send_auto_ack(&mut self) {
312
313 let mut mutable = self.shared.mutable.lock().unwrap();
314 let SharedMutable {
315 channels,
316 encryption_packet,
317 ..
318 } = &mut *mutable;
319
320 for (addr, channel) in channels {
321 if channel.take_auto_ack() {
322
323 let mut packet_config = PacketConfig::new();
324 let ack = channel.get_cumulative_ack_exclusive().expect("incoherent");
325 packet_config.set_sequence_num(ack);
326 packet_config.set_cumulative_ack(ack);
327 packet_config.set_on_channel(true);
328 packet_config.set_unk_1000(0);
329
330 let mut packet = Packet::new_boxed();
331 packet.write_config(&mut packet_config);
332
333 encrypt_packet(packet.raw(), &channel.blowfish, &mut **encryption_packet);
334
335 self.shared.socket.send_to(encryption_packet.data(), *addr).unwrap();
337
338 }
339 }
340
341 }
342
343 pub fn take_rejected_packets(&mut self) -> Vec<(SocketAddr, Box<Packet>, PacketRejectionError)> {
345
346 let mut mutable = self.shared.mutable.lock().unwrap();
347 let SharedMutable {
348 fragments,
349 rejected_packets,
350 ..
351 } = &mut *mutable;
352
353 let instant = Instant::now();
355 fragments.retain(|(addr, _), fragments| {
356 if fragments.is_old(instant, FRAGMENT_TIMEOUT) {
357 rejected_packets.extend(fragments.drain()
358 .map(|packet| (*addr, packet, PacketRejectionError::TimedOut)));
359 false
360 } else {
361 true
362 }
363 });
364
365 std::mem::take(rejected_packets)
367
368 }
369
370}
371
372
/// Magic value placed (little-endian) just before the wastage byte at the
/// end of every encrypted packet body; verified on decryption.
const ENCRYPTION_MAGIC: [u8; 4] = 0xDEADBEEFu32.to_le_bytes();
/// Length of the encryption footer: the magic plus one wastage byte.
const ENCRYPTION_FOOTER_LEN: usize = ENCRYPTION_MAGIC.len() + 1;
377
378
379fn decrypt_packet(packet: &Packet, bf: &Blowfish) -> Result<Box<Packet>, ()> {
385
386 let len = packet.raw().data_len();
387
388 let mut clear_packet = Packet::new_boxed();
390 clear_packet.raw_mut().set_data_len(len);
391
392 let src = packet.raw().body();
396 let dst = clear_packet.raw_mut().body_mut();
397
398 if src.len() % BLOCK_SIZE != 0 || src.len() < ENCRYPTION_FOOTER_LEN {
402 return Err(())
403 }
404
405 io::copy(
407 &mut BlowfishReader::new(Cursor::new(src), &bf),
408 &mut Cursor::new(&mut *dst),
409 ).unwrap();
410
411 let wastage_begin = src.len() - 1;
412 let magic_begin = wastage_begin - 4;
413
414 if &dst[magic_begin..wastage_begin] != &ENCRYPTION_MAGIC {
416 return Err(())
417 }
418
419 let wastage = dst[wastage_begin];
422 assert!(wastage <= BLOCK_SIZE as u8, "temporary check that wastage is not greater than block size");
423
424 clear_packet.raw_mut().set_data_len(len - wastage as usize - ENCRYPTION_MAGIC.len());
425 clear_packet.raw_mut().write_prefix(packet.raw().read_prefix());
427
428 Ok(clear_packet)
429
430}
431
432
433fn encrypt_packet(src_packet: &RawPacket, bf: &Blowfish, dst_packet: &mut RawPacket) {
439
440 let mut len = src_packet.body_len() + ENCRYPTION_FOOTER_LEN;
442
443 let padding = (BLOCK_SIZE - (len % BLOCK_SIZE)) % BLOCK_SIZE;
445 len += padding;
446
447 let mut clear_data = Vec::from(src_packet.body());
449 clear_data.reserve_exact(padding + ENCRYPTION_FOOTER_LEN);
450 clear_data.extend_from_slice(&[0u8; BLOCK_SIZE - 1][..padding]); clear_data.extend_from_slice(&ENCRYPTION_MAGIC); clear_data.push(padding as u8 + 1); debug_assert_eq!(clear_data.len(), len, "incoherent length");
455 debug_assert_eq!(clear_data.len() % 8, 0, "data not padded as expected");
456
457 dst_packet.set_data_len(clear_data.len() + 4);
459
460 io::copy(
462 &mut Cursor::new(&clear_data[..]),
463 &mut BlowfishWriter::new(Cursor::new(dst_packet.body_mut()), bf),
464 ).unwrap();
465
466 dst_packet.write_prefix(src_packet.read_prefix());
468
469}
470
471
/// Per-peer channel state: the encryption cipher plus reliable-ack
/// bookkeeping in both directions.
#[derive(Debug)]
pub struct Channel {
    /// Symmetric cipher used to encrypt/decrypt packets on this channel.
    blowfish: Arc<Blowfish>,
    /// Sorted sequence numbers of reliable packets we sent that the peer
    /// has not yet cumulatively acknowledged.
    sent_acks: Vec<u32>,
    /// Sorted, deduplicated sequence numbers of reliable packets received
    /// from the peer.
    received_acks: Vec<u32>,
    /// Set when a received reliable packet requires an automatic
    /// acknowledgment; consumed by `send_auto_ack`.
    auto_ack: bool,
}
485
486impl Channel {
487
488 fn new(blowfish: Arc<Blowfish>) -> Self {
489 Self {
490 blowfish,
491 sent_acks: Vec::new(),
492 received_acks: Vec::new(),
493 auto_ack: false,
494 }
495 }
496
497 fn add_sent_ack(&mut self, sequence_num: u32) {
498
499 debug_assert!(
500 self.sent_acks.is_empty() || *self.sent_acks.last().unwrap() < sequence_num,
501 "sequence number is not ordered"
502 );
503
504 self.sent_acks.push(sequence_num);
505 println!("[AFTER ADD] sent_acks: {:?}", self.sent_acks);
506
507 }
508
509 fn remove_cumulative_ack(&mut self, ack: u32) {
510
511 let discard_offset = match self.sent_acks.binary_search(&ack) {
512 Ok(index) => index,
513 Err(index) => index,
514 };
515
516 self.sent_acks.drain(..discard_offset);
517 println!("[AFTER REM] sent_acks: {:?}", self.sent_acks);
518
519 }
520
521 fn add_received_ack(&mut self, sequence_num: u32) {
522
523 match self.received_acks.binary_search(&sequence_num) {
524 Ok(_) => {
525 }
527 Err(index) => {
528 self.received_acks.insert(index, sequence_num);
529 }
530 }
531
532 println!("[AFTER ADD] received_acks: {:?}", self.received_acks);
533
534 }
535
536 #[inline]
537 fn set_auto_ack(&mut self) {
538 self.auto_ack = true;
539 }
540
541 #[inline]
543 fn take_auto_ack(&mut self) -> bool {
544 std::mem::replace(&mut self.auto_ack, false)
545 }
546
547 fn get_cumulative_ack(&mut self) -> Option<u32> {
549
550 let first_ack = *self.received_acks.get(0)?;
551 let mut cumulative_ack = first_ack;
552
553 for &sequence_num in &self.received_acks[1..] {
554 if sequence_num == cumulative_ack + 1 {
555 cumulative_ack += 1;
556 } else {
557 break
558 }
559 }
560
561 if cumulative_ack > first_ack {
562 let diff = cumulative_ack - first_ack;
563 self.received_acks.drain(..diff as usize);
564 }
565
566 Some(cumulative_ack)
567
568 }
569
570 #[inline]
571 fn get_cumulative_ack_exclusive(&mut self) -> Option<u32> {
572 self.get_cumulative_ack().map(|n| n + 1)
573 }
574
575}
576
577struct BundleFragments {
579 fragments: Vec<Option<Box<Packet>>>, seq_count: u32,
581 last_update: Instant,
582}
583
584impl BundleFragments {
585
586 fn new(seq_len: u32) -> Self {
588 Self {
589 fragments: (0..seq_len).map(|_| None).collect(),
590 seq_count: 0,
591 last_update: Instant::now()
592 }
593 }
594
595 fn drain(&mut self) -> impl Iterator<Item = Box<Packet>> {
597 self.seq_count = 0;
598 std::mem::take(&mut self.fragments)
599 .into_iter()
600 .filter_map(|slot| slot)
601 }
602
603 fn set(&mut self, num: u32, packet: Box<Packet>) {
605 let frag = &mut self.fragments[num as usize];
606 if frag.is_none() {
607 self.seq_count += 1;
608 }
609 self.last_update = Instant::now();
610 *frag = Some(packet);
611 }
612
613 #[inline]
614 fn is_old(&self, instant: Instant, timeout: Duration) -> bool {
615 instant - self.last_update > timeout
616 }
617
618 #[inline]
619 fn is_full(&self) -> bool {
620 self.seq_count as usize == self.fragments.len()
621 }
622
623 #[inline]
625 fn into_bundle(self) -> Bundle {
626 assert!(self.is_full());
627 let packets = self.fragments.into_iter()
628 .map(|o| o.unwrap())
629 .collect();
630 Bundle::with_multiple(packets)
631 }
632
633}
634
635
/// Reason a received packet was rejected; collected alongside the packet
/// and returned by `BundleSocket::take_rejected_packets`.
#[derive(Debug, Clone, thiserror::Error)]
pub enum PacketRejectionError {
    /// The packet belonged to a fragmented bundle that timed out before
    /// being completed.
    #[error("timed out")]
    TimedOut,
    /// The packet failed decryption: bad body size, wrong magic, or an
    /// incoherent wastage byte.
    #[error("invalid encryption")]
    InvalidEncryption,
    /// The packet's configuration footer could not be parsed.
    #[error("sync error: {0}")]
    Config(#[from] PacketConfigError),
}