wgtk/net/
socket.rs

//! Provides a bundle-oriented socket, backed by a UDP socket.
2
3use std::net::{SocketAddr, SocketAddrV4, UdpSocket};
4use std::collections::{HashMap, hash_map};
5use std::time::{Duration, Instant};
6use std::sync::{Arc, Mutex};
7use std::io::{self, Cursor};
8
9use blowfish::Blowfish;
10
11use super::filter::{BlowfishReader, BlowfishWriter, blowfish::BLOCK_SIZE};
12use super::packet::{Packet, RawPacket, PacketConfig, PacketConfigError};
13use super::bundle::Bundle;
14
15
/// The (currently hardcoded) timeout on bundle fragments: a partially received
/// sequence whose last update is older than this is rejected with
/// `PacketRejectionError::TimedOut`.
const FRAGMENT_TIMEOUT: Duration = Duration::from_secs(10);
18
19
/// A socket providing an interface for sending and receiving bundles of elements,
/// backed by a UDP socket with support for blowfish channel encryption. This socket
/// is blocking on sends and receives.
/// 
/// It also provides fragmentation support when sending bundles that contain more than 
/// one packet, channel blowfish encryption and packet acknowledgment.
/// 
/// This socket handle is actually just a shared pointer to shared data, it can be cloned
/// as needed and used in multiple threads at the same time.
#[derive(Clone)]
pub struct BundleSocket {
    /// Shared data (bound address, UDP socket and mutex-protected state).
    shared: Arc<Shared>,
}
34
/// A reference counted shared underlying socket data.
struct Shared {
    /// Bound address of the UDP socket.
    addr: SocketAddrV4,
    /// The socket used for sending and receiving UDP packets.
    socket: UdpSocket,
    /// The mutable part of the shared data, behind a mutex lock.
    mutable: Mutex<SharedMutable>,
}
44
/// Mutable shared socket data.
struct SharedMutable {
    /// Bundle fragments tracking, keyed by (remote address, first sequence number
    /// of the fragmented bundle).
    fragments: HashMap<(SocketAddr, u32), BundleFragments>,
    /// The next sequence number to use for sent bundles.
    next_sequence_num: u32,
    /// Registered channels on the app that define a particular blowfish
    /// key for packet encryption and decryption, keyed by remote address.
    channels: HashMap<SocketAddr, Channel>,
    /// A raw packet used as a temporary buffer for blowfish encryption,
    /// reused across sends to avoid reallocating.
    encryption_packet: Box<RawPacket>,
    /// List of rejected packets, with the address they came from and the reason.
    rejected_packets: Vec<(SocketAddr, Box<Packet>, PacketRejectionError)>,
}
59
impl BundleSocket {

    /// Create a new socket bound to the given IPv4 address.
    ///
    /// The underlying UDP socket is created in its default (blocking) mode.
    pub fn new(addr: SocketAddrV4) -> io::Result<Self> {

        let socket = UdpSocket::bind(SocketAddr::V4(addr))?;
        
        Ok(Self {
            shared: Arc::new(Shared { 
                addr, 
                socket, 
                mutable: Mutex::new(SharedMutable {
                    fragments: HashMap::new(),
                    next_sequence_num: 0,
                    channels: HashMap::new(),
                    encryption_packet: Box::new(RawPacket::new()),
                    rejected_packets: Vec::new(),
                }),
            }),
        })

    }

    /// Get the bind address of this socket.
    #[inline]
    pub fn addr(&self) -> SocketAddrV4 {
        self.shared.addr
    }

    /// Associate a new channel to the given address with the given blowfish
    /// encryption. This blowfish encryption will be used for all 
    /// transactions to come with this given socket address.
    ///
    /// Note that any previously registered channel for this address is
    /// replaced, dropping its pending ack state.
    pub fn set_channel(&mut self, addr: SocketAddr, blowfish: Arc<Blowfish>) {
        self.shared.mutable.lock().unwrap()
            .channels.insert(addr, Channel::new(blowfish));
    }

    /// Send a bundle to a given address, returning the total number of bytes
    /// written to the UDP socket. Note that the bundle is finalized by this
    /// method with the internal sequence id.
    /// 
    /// Note that the net data is guaranteed to be untouched by this function,
    /// this includes prefix, flags up to the footer. Bytes beyond this limit
    /// might be modified in case of channel encryption.
    pub fn send(&mut self, bundle: &mut Bundle, to: SocketAddr) -> io::Result<usize> {

        // Do nothing if bundle is empty.
        if bundle.is_empty() {
            return Ok(0)
        }

        // NOTE: This may potentially block if a received packet is being processed.
        let mut mutable = self.shared.mutable.lock().unwrap();
        let SharedMutable {
            next_sequence_num,
            channels,
            encryption_packet,
            ..
        } = &mut *mutable;

        // Get a potential reference to the channel this address is linked to.
        let mut channel = channels.get_mut(&to);

        // Compute first and last sequence num and directly update next sequence num.
        let sequence_first_num = *next_sequence_num;
        *next_sequence_num = next_sequence_num.checked_add(bundle.len() as u32).expect("sequence num overflow");
        let sequence_last_num = *next_sequence_num - 1;

        // Create a common packet config shared by all packets of the bundle.
        let mut packet_config = PacketConfig::new();

        // When on channel, we set appropriate flags and values.
        if let Some(channel) = channel.as_deref_mut() {

            packet_config.set_on_channel(true);
            packet_config.set_reliable(true);
            
            // If we send a cumulative ack, just take auto ack to avoid resending 
            // it automatically.
            // NOTE(review): `packet_config` was created just above, so
            // `cumulative_ack()` is presumably still unset at this point and this
            // branch looks dead — confirm what `PacketConfig::new()` defaults to;
            // the cumulative ack is only set later, inside the send loop.
            if packet_config.cumulative_ack().is_some() {
                channel.take_auto_ack();
            }

        }

        // If multi-packet bundle, set sequence range.
        if sequence_last_num > sequence_first_num {
            packet_config.set_sequence_range(sequence_first_num, sequence_last_num);
        }

        let mut size = 0;
        let mut sequence_num = sequence_first_num;

        for packet in bundle.packets_mut() {

            // Only the last packet carries a cumulative ack (if any is available).
            if sequence_num == sequence_last_num {
                if let Some(channel) = channel.as_mut() {
                    if let Some(num) = channel.get_cumulative_ack_exclusive() {
                        packet_config.set_cumulative_ack(num);
                    } else {
                        packet_config.clear_cumulative_ack();
                    }
                }
            }

            // Set sequence number and sync the packet's data from the config.
            packet_config.set_sequence_num(sequence_num);
            packet.write_config(&mut packet_config);

            // Reference to the actual raw packet to send.
            let raw_packet;

            if let Some(channel) = channel.as_deref_mut() {

                // Remember this reliable packet until the remote end acknowledges it.
                channel.add_sent_ack(sequence_num);

                // Encrypt into the shared temporary packet and send that one instead.
                encrypt_packet(packet.raw(), &channel.blowfish, &mut **encryption_packet);
                raw_packet = &**encryption_packet;

            } else {
                raw_packet = packet.raw()
            }
            
            // println!("Sending {:X}", BytesFmt(raw_packet.data()));

            size += self.shared.socket.send_to(raw_packet.data(), to)?;
            sequence_num += 1;

        }

        Ok(size)

    }

    /// Blocking receive of a packet, if a bundle can be constructed it is returned, if
    /// not, none is returned instead. If the packet is rejected for any reason listed
    /// in [`PacketRejectionError`], none is also returned but the packet is internally
    /// queued and can later be retrieved with the error using 
    /// [`Self::take_rejected_packets()`].
    pub fn recv(&mut self) -> io::Result<Option<Bundle>> {

        let mut packet = Packet::new_boxed();
        let (len, addr) = self.shared.socket.recv_from(packet.raw_mut().raw_data_mut())?;

        // Adjust the data length depending on what has been received.
        packet.raw_mut().set_data_len(len);

        // NOTE: We lock only once we received the packet, so it's not blocking any other
        // handle to this socket that wants to send bundles.
        let mut mutable = self.shared.mutable.lock().unwrap();
        let SharedMutable {
            channels,
            fragments,
            rejected_packets,
            ..
        } = &mut *mutable;

        // Get a potential reference to the channel this address is linked to.
        let mut channel = channels.get_mut(&addr);

        // If the address is linked to a channel, we need to decrypt it according to the 
        // channel's blowfish key.
        if let Some(channel) = channel.as_deref_mut() {
            match decrypt_packet(&packet, &channel.blowfish) {
                Ok(clear_packet) => packet = clear_packet,
                Err(()) => {
                    // NOTE(review): goes through `mutable.` although a destructured
                    // `rejected_packets` binding exists above (used further down);
                    // consider using the binding for consistency.
                    mutable.rejected_packets.push((addr, packet, PacketRejectionError::InvalidEncryption));
                    return Ok(None);
                }
            }
        }

        // Retrieve the real clear-text length after a potential decryption.
        let len = packet.raw().data_len();

        // TODO: Use thread-local for packet config?
        let mut packet_config = PacketConfig::new();
        if let Err(error) = packet.read_config(len, &mut packet_config) {
            mutable.rejected_packets.push((addr, packet, PacketRejectionError::Config(error)));
            return Ok(None);
        }

        // Again, if we are in a channel, we handle packet acknowledgment.
        if let Some(channel) = channel.as_deref_mut() {

            // If packet is reliable, take its ack number and store it for future acknowledging.
            if packet_config.reliable() {
                channel.add_received_ack(packet_config.sequence_num());
                channel.set_auto_ack();
            }

            // A cumulative ack from the remote end acknowledges our own sent packets.
            if let Some(ack) = packet_config.cumulative_ack() {
                channel.remove_cumulative_ack(ack);
            }

        }

        // We can observe that packets with the flag 0x1000 are only used
        // for auto acking with sometimes duplicated data that is sent
        // just after. If it becomes a problem this check can be removed. 
        if packet_config.unk_1000().is_none() {
            
            let instant = Instant::now();

            match packet_config.sequence_range() {
                // Only if there is a range and this range is not a single num.
                Some((first_num, last_num)) if last_num > first_num => {

                    let num = packet_config.sequence_num();

                    match fragments.entry((addr, first_num)) {
                        hash_map::Entry::Occupied(mut o) => {

                            // If this fragments is too old, timeout every packet in it
                            // and start again with the packet.
                            // FIXME: Maybe dumb?
                            if o.get().is_old(instant, FRAGMENT_TIMEOUT) {
                                rejected_packets.extend(o.get_mut().drain()
                                    .map(|packet| (addr, packet, PacketRejectionError::TimedOut)));
                            }

                            // NOTE(review): `num` is the absolute sequence number, but
                            // `BundleFragments::set` indexes a vec of length
                            // `last_num - first_num + 1` with it — this looks like it
                            // should be `num - first_num`; confirm how sequence numbers
                            // map to fragment slots.
                            o.get_mut().set(num, packet);

                            // When all fragments are collected, remove entry and return.
                            if o.get().is_full() {
                                return Ok(Some(o.remove().into_bundle()));
                            }

                        },
                        hash_map::Entry::Vacant(v) => {
                            let mut fragments = BundleFragments::new(last_num - first_num + 1);
                            fragments.set(num, packet);
                            v.insert(fragments);
                        }
                    }

                }
                // No sequence range in the packet, create a bundle only with it.
                _ => {
                    return Ok(Some(Bundle::with_single(packet)));
                }
            }

        }

        // No error but no full bundle received.
        Ok(None)

    }

    /// Send all automatic packet acknowledgments: for every channel whose auto-ack
    /// flag is set, a standalone ack packet (with the 0x1000 flag) carrying the
    /// cumulative ack is encrypted and sent to the channel's address.
    pub fn send_auto_ack(&mut self) {

        let mut mutable = self.shared.mutable.lock().unwrap();
        let SharedMutable {
            channels,
            encryption_packet,
            ..
        } = &mut *mutable;

        for (addr, channel) in channels {
            if channel.take_auto_ack() {
                
                let mut packet_config = PacketConfig::new();
                // The auto-ack flag is only set right after a received ack is
                // stored, so a cumulative ack is expected to exist here.
                let ack = channel.get_cumulative_ack_exclusive().expect("incoherent");
                packet_config.set_sequence_num(ack);
                packet_config.set_cumulative_ack(ack);
                packet_config.set_on_channel(true);
                packet_config.set_unk_1000(0);

                let mut packet = Packet::new_boxed();
                packet.write_config(&mut packet_config);

                encrypt_packet(packet.raw(), &channel.blowfish, &mut **encryption_packet);

                // println!("Sending auto ack {:X}", BytesFmt(self.encryption_packet.data()));
                // NOTE(review): unwraps on UDP send errors; consider propagating
                // or collecting the error instead of panicking.
                self.shared.socket.send_to(encryption_packet.data(), *addr).unwrap();

            }
        }

    }

    /// Take the vector of all rejected packets with their source address and the
    /// rejection reason. Timed-out bundle fragments are collected into the
    /// returned vector as well, and their entries removed.
    pub fn take_rejected_packets(&mut self) -> Vec<(SocketAddr, Box<Packet>, PacketRejectionError)> {

        let mut mutable = self.shared.mutable.lock().unwrap();
        let SharedMutable {
            fragments,
            rejected_packets,
            ..
        } = &mut *mutable;

        // Before returning the vector, take all timed out fragments.
        let instant = Instant::now();
        fragments.retain(|(addr, _), fragments| {
            if fragments.is_old(instant, FRAGMENT_TIMEOUT) {
                rejected_packets.extend(fragments.drain()
                    .map(|packet| (*addr, packet, PacketRejectionError::TimedOut)));
                false
            } else {
                true
            }
        });

        // NOTE: We just take the vector, so mutable data is not locked for too long.
        std::mem::take(rejected_packets)

    }

}
371
372
/// Encryption magic, 0xDEADBEEF in little endian, appended before the wastage byte.
const ENCRYPTION_MAGIC: [u8; 4] = 0xDEADBEEFu32.to_le_bytes();
/// Encryption footer length: 4 bytes magic + 1 byte for the wastage count.
const ENCRYPTION_FOOTER_LEN: usize = ENCRYPTION_MAGIC.len() + 1;
377
378
379/// Decrypt a packet of a given length with a blowfish key.
380/// 
381/// This returns an empty error if the encryption is invalid.
382/// If successful the clear packet is returned with its size, the size can then
383/// be used to synchronize the packet's state to its data.
384fn decrypt_packet(packet: &Packet, bf: &Blowfish) -> Result<Box<Packet>, ()> {
385
386    let len = packet.raw().data_len();
387
388    // Create a packet that have the same length as input packet.
389    let mut clear_packet = Packet::new_boxed();
390    clear_packet.raw_mut().set_data_len(len);
391
392    // Decrypt the incoming packet into the new clear packet.
393    // We don't need to set the length yet because this packet 
394    // will be synchronized just after.
395    let src = packet.raw().body();
396    let dst = clear_packet.raw_mut().body_mut();
397    
398    // Note that src and dst have the same length, thanks to blowfish encryption.
399    // Then we can already check the length and ensures that it is a multiple of
400    // blowfish block size *and* can contain the wastage and encryption magic.
401    if src.len() % BLOCK_SIZE != 0 || src.len() < ENCRYPTION_FOOTER_LEN {
402        return Err(())
403    }
404
405    // Unwrapping because we know that source/destination have the same length.
406    io::copy(
407        &mut BlowfishReader::new(Cursor::new(src), &bf), 
408        &mut Cursor::new(&mut *dst),
409    ).unwrap();
410
411    let wastage_begin = src.len() - 1;
412    let magic_begin = wastage_begin - 4;
413
414    // Check invalid magic.
415    if &dst[magic_begin..wastage_begin] != &ENCRYPTION_MAGIC {
416        return Err(())
417    }
418
419    // Get the wastage count and compute the packet's length.
420    // Note that wastage count also it self length.
421    let wastage = dst[wastage_begin];
422    assert!(wastage <= BLOCK_SIZE as u8, "temporary check that wastage is not greater than block size");
423
424    clear_packet.raw_mut().set_data_len(len - wastage as usize - ENCRYPTION_MAGIC.len());
425    // Copy the prefix directly because it is clear.
426    clear_packet.raw_mut().write_prefix(packet.raw().read_prefix());
427
428    Ok(clear_packet)
429
430}
431
432
433/// Encrypt source packet with the given blowfish key and write it to the destination
434/// raw packet. Everything except the packet prefix is encrypted, and the destination
435/// packet will have a size that is a multiple of blowfish's block size (8). The clear
436/// data is also padded to block size, but with additional data at the end: encryption
437/// signature (0xDEADBEEF in little endian) and the wastage count + 1 on the last byte.
438fn encrypt_packet(src_packet: &RawPacket, bf: &Blowfish, dst_packet: &mut RawPacket) {
439    
440    // Get the minimum, unpadded length of this packet with encryption footer appended to it.
441    let mut len = src_packet.body_len() + ENCRYPTION_FOOTER_LEN;
442
443    // The wastage amount is basically the padding + 1 for the wastage itself.
444    let padding = (BLOCK_SIZE - (len % BLOCK_SIZE)) % BLOCK_SIZE;
445    len += padding;
446
447    // Clone the packet data into a new vec and append the padding and the footer.
448    let mut clear_data = Vec::from(src_packet.body());
449    clear_data.reserve_exact(padding + ENCRYPTION_FOOTER_LEN);
450    clear_data.extend_from_slice(&[0u8; BLOCK_SIZE - 1][..padding]); // Padding
451    clear_data.extend_from_slice(&ENCRYPTION_MAGIC); // Magic
452    clear_data.push(padding as u8 + 1); // Wastage count (+1 for it self size)
453
454    debug_assert_eq!(clear_data.len(), len, "incoherent length");
455    debug_assert_eq!(clear_data.len() % 8, 0, "data not padded as expected");
456    
457    // +4 for the prefix.
458    dst_packet.set_data_len(clear_data.len() + 4);
459
460    // Unwrapping because we know that source/destination have the same length.
461    io::copy(
462        &mut Cursor::new(&clear_data[..]), 
463        &mut BlowfishWriter::new(Cursor::new(dst_packet.body_mut()), bf),
464    ).unwrap();
465    
466    // Copy the prefix directly because it is clear.
467    dst_packet.write_prefix(src_packet.read_prefix());
468
469}
470
471
/// Represent a channel between the app and a client with a specific socket address.
#[derive(Debug)]
pub struct Channel {
    /// The blowfish key used for encryption of this channel.
    blowfish: Arc<Blowfish>,
    /// The list of sent sequence numbers that are pending acknowledgment. They are
    /// kept ordered in the vector, so a simple binary search is enough.
    sent_acks: Vec<u32>,
    /// The sorted list of received (reliable) sequence numbers, used when sending
    /// cumulative acks back to the remote end.
    received_acks: Vec<u32>,
    /// Set to true when an ack should be sent even if no bundle is sent.
    auto_ack: bool,
}
485
486impl Channel {
487
488    fn new(blowfish: Arc<Blowfish>) -> Self {
489        Self {
490            blowfish,
491            sent_acks: Vec::new(),
492            received_acks: Vec::new(),
493            auto_ack: false,
494        }
495    }
496
497    fn add_sent_ack(&mut self, sequence_num: u32) {
498
499        debug_assert!(
500            self.sent_acks.is_empty() || *self.sent_acks.last().unwrap() < sequence_num,
501            "sequence number is not ordered"
502        );
503
504        self.sent_acks.push(sequence_num);
505        println!("[AFTER ADD] sent_acks: {:?}", self.sent_acks);
506
507    }
508
509    fn remove_cumulative_ack(&mut self, ack: u32) {
510        
511        let discard_offset = match self.sent_acks.binary_search(&ack) {
512            Ok(index) => index,
513            Err(index) => index,
514        };
515
516        self.sent_acks.drain(..discard_offset);
517        println!("[AFTER REM] sent_acks: {:?}", self.sent_acks);
518
519    }
520
521    fn add_received_ack(&mut self, sequence_num: u32) {
522
523        match self.received_acks.binary_search(&sequence_num) {
524            Ok(_) => {
525                // Maybe an error to receive the same ack twice?
526            }
527            Err(index) => {
528                self.received_acks.insert(index, sequence_num);
529            }
530        }
531
532        println!("[AFTER ADD] received_acks: {:?}", self.received_acks);
533
534    }
535
536    #[inline]
537    fn set_auto_ack(&mut self) {
538        self.auto_ack = true;
539    }
540
541    /// Take the auto ack and disable it anyway.
542    #[inline]
543    fn take_auto_ack(&mut self) -> bool {
544        std::mem::replace(&mut self.auto_ack, false)
545    }
546
547    /// Return the last ack that is part of an chain.
548    fn get_cumulative_ack(&mut self) -> Option<u32> {
549
550        let first_ack = *self.received_acks.get(0)?;
551        let mut cumulative_ack = first_ack;
552
553        for &sequence_num in &self.received_acks[1..] {
554            if sequence_num == cumulative_ack + 1 {
555                cumulative_ack += 1;
556            } else {
557                break
558            }
559        }
560
561        if cumulative_ack > first_ack {
562            let diff = cumulative_ack - first_ack;
563            self.received_acks.drain(..diff as usize);
564        }
565
566        Some(cumulative_ack)
567
568    }
569
570    #[inline]
571    fn get_cumulative_ack_exclusive(&mut self) -> Option<u32> {
572        self.get_cumulative_ack().map(|n| n + 1)
573    }
574
575}
576
/// Internal structure to keep fragments from a given sequence.
struct BundleFragments {
    /// One slot per sequence number; boxes avoid moving huge packet structures.
    fragments: Vec<Option<Box<Packet>>>,
    /// Count of slots currently filled.
    seq_count: u32,
    /// Instant of the last fragment insertion, used for timeout tracking.
    last_update: Instant,
}
583
584impl BundleFragments {
585
586    /// Create from sequence length.
587    fn new(seq_len: u32) -> Self {
588        Self {
589            fragments: (0..seq_len).map(|_| None).collect(),
590            seq_count: 0,
591            last_update: Instant::now()
592        }
593    }
594
595    /// This this fragments packets and reset internal count to zero.
596    fn drain(&mut self) -> impl Iterator<Item = Box<Packet>> {
597        self.seq_count = 0;
598        std::mem::take(&mut self.fragments)
599            .into_iter()
600            .filter_map(|slot| slot)
601    }
602
603    /// Set a fragment.
604    fn set(&mut self, num: u32, packet: Box<Packet>) {
605        let frag = &mut self.fragments[num as usize];
606        if frag.is_none() {
607            self.seq_count += 1;
608        }
609        self.last_update = Instant::now();
610        *frag = Some(packet);
611    }
612
613    #[inline]
614    fn is_old(&self, instant: Instant, timeout: Duration) -> bool {
615        instant - self.last_update > timeout
616    }
617
618    #[inline]
619    fn is_full(&self) -> bool {
620        self.seq_count as usize == self.fragments.len()
621    }
622
623    /// Convert this structure to a bundle, **safe to call only if `is_full() == true`**.
624    #[inline]
625    fn into_bundle(self) -> Bundle {
626        assert!(self.is_full());
627        let packets = self.fragments.into_iter()
628            .map(|o| o.unwrap())
629            .collect();
630        Bundle::with_multiple(packets)
631    }
632
633}
634
635
/// Kind of error that caused a packet to be rejected from this socket and not received.
#[derive(Debug, Clone, thiserror::Error)]
pub enum PacketRejectionError {
    /// The packet is part of a sequence but the other packets of the sequence have
    /// not been received in time, therefore no bundle can be reconstructed.
    #[error("timed out")]
    TimedOut,
    /// The packet should be decrypted but decryption failed (bad length or magic).
    #[error("invalid encryption")]
    InvalidEncryption,
    /// The packet configuration could not be synchronized from its data.
    #[error("sync error: {0}")]
    Config(#[from] PacketConfigError),
}