udpsec/
lib.rs

1use std::collections::HashMap;
2use std::net::{UdpSocket, SocketAddr};
3use std::io::{ErrorKind, Error};
4use std::time::{Duration, Instant};
5use std::sync::mpsc;
6
7use rand_core::{OsRng, RngCore};
8use x25519_dalek::{PublicKey, ReusableSecret};
9use wait_not_await::Await;
10
11pub use rand_core;
12pub use x25519_dalek;
13pub use wait_not_await;
14
15mod tests;
16
/// Number of checksum bytes prepended to every datagram by `Socket.send()` (see `get_checksum()`)
pub const CHECKSUM_LENGTH: usize = 4;

/// Maximal amount of bytes you can send with `Socket.send()` method
///
/// Starts from the largest UDP payload over IPv4, 65507 bytes
/// (65535 - 20 bytes of IPv4 header - 8 bytes of UDP header), and subtracts:
///
/// - 1 byte reserved by the packet type
/// - `CHECKSUM_LENGTH` bytes reserved by the `get_checksum()`
///
/// This assumes the configured encoder keeps the datagram length unchanged
/// (as `xor_encode` and `plain_text` do).
pub const DATAGRAM_MAX_LENGTH: usize = 65507 - 1 - CHECKSUM_LENGTH;
24
25fn rand_u32() -> u32 {
26    OsRng::default().next_u32()
27}
28
29fn rand_u8() -> u8 {
30    (rand_u32() % 256) as u8
31}
32
33#[derive(Debug, Clone, PartialEq, Eq)]
34enum Packet {
35    KeyExchangeInit(PublicKey),
36    KeyExchangeDone(PublicKey),
37    Datagram(Vec<u8>)
38}
39
40impl Packet {
41    pub fn to_bytes(&self) -> Vec<u8> {
42        match self {
43            // 0-85
44            Packet::KeyExchangeInit(public_key) => {
45                let mut packet = Vec::with_capacity(33);
46
47                packet.push(rand_u8() % 86);
48                packet.append(&mut public_key.as_bytes().to_vec());
49
50                // Random noise
51                for _ in 0..rand_u8() {
52                    let mut rand = rand_u32();
53
54                    while rand > 0 {
55                        packet.push((rand % 256) as u8);
56
57                        rand /= 256;
58                    }
59                }
60
61                packet
62            },
63
64            // 86-170
65            Packet::KeyExchangeDone(public_key) => {
66                let mut packet = Vec::with_capacity(33);
67
68                packet.push(rand_u8() % 85 + 86);
69                packet.append(&mut public_key.as_bytes().to_vec());
70
71                // Random noise
72                for _ in 0..rand_u8() {
73                    let mut rand = rand_u32();
74
75                    while rand > 0 {
76                        packet.push((rand % 256) as u8);
77
78                        rand /= 256;
79                    }
80                }
81
82                packet
83            },
84
85            // 171-255
86            // datagram size: 65536
87            Packet::Datagram(data) => {
88                [vec![rand_u8() % 85 + 171], data.clone()].concat()
89            }
90        }
91    }
92
93    pub fn from_bytes(bytes: &[u8]) -> Result<Packet, Error> {
94        if bytes.len() > 0 {
95            if bytes[0] < 86 {
96                let mut public_key = [0u8; 32];
97
98                public_key.copy_from_slice(&bytes[1..33]);
99
100                match PublicKey::try_from(public_key) {
101                    Ok(public_key) => Ok(Packet::KeyExchangeInit(public_key)),
102                    Err(_) => Err(Error::new(ErrorKind::InvalidData, "Public key decoding error"))
103                }
104            }
105
106            else if bytes[0] < 171 {
107                let mut public_key = [0u8; 32];
108
109                public_key.copy_from_slice(&bytes[1..33]);
110
111                match PublicKey::try_from(public_key) {
112                    Ok(public_key) => Ok(Packet::KeyExchangeDone(public_key)),
113                    Err(_) => Err(Error::new(ErrorKind::InvalidData, "Public key decoding error"))
114                }
115            }
116
117            else {
118                Ok(Packet::Datagram(bytes[1..].to_vec()))
119            }
120        }
121
122        else {
123            Err(Error::new(ErrorKind::InvalidInput, "Slice is empty"))
124        }
125    }
126}
127
/// XOR-encodes `data` with `key`; applying it twice with the same key restores the input
///
/// Byte `i` of `data` is XORed with byte `i % 32` of `key`, i.e. the key repeats every
/// 32 bytes. Note that a repeating-key XOR obfuscates data but does not provide real
/// confidentiality.
///
/// Used in both `Socket.encoder` and `Socket.decoder` by default
pub fn xor_encode(mut data: Vec<u8>, key: &[u8; 32]) -> Vec<u8> {
    data.iter_mut()
        .zip(key.iter().cycle())
        .for_each(|(byte, key_byte)| *byte ^= key_byte);

    data
}
138
/// Identity codec: returns `data` unchanged and ignores the key
///
/// Install it as `Socket.encoder` / `Socket.decoder` (e.g. via `set_encoder()` /
/// `set_decoder()`) to send datagrams without any encoding
pub fn plain_text(data: Vec<u8>, _key: &[u8; 32]) -> Vec<u8> {
    data
}
145
146/// This functions calculates checksum of the input data
147fn get_checksum(data: &[u8]) -> [u8; CHECKSUM_LENGTH] {
148    let mut checksum = [0; CHECKSUM_LENGTH];
149
150    for i in 0..data.len() {
151        let mut sum = u16::from(checksum[i % CHECKSUM_LENGTH]) + u16::from(data[i]);
152
153        if sum > 255 {
154            sum %= 256;
155        }
156
157        checksum[i % CHECKSUM_LENGTH] = sum as u8;
158    }
159
160    checksum
161}
162
163pub struct Socket {
164    addr: SocketAddr,
165    socket: UdpSocket,
166    secrets: HashMap<SocketAddr, [u8; 32]>,
167    floating_connections: HashMap<SocketAddr, (ReusableSecret, mpsc::Sender<()>)>,
168
169    /// Datagrams encoder
170    pub encoder: Box<dyn Fn(Vec<u8>, &[u8; 32]) -> Vec<u8>>,
171
172    /// Datagrams decoder
173    pub decoder: Box<dyn Fn(Vec<u8>, &[u8; 32]) -> Vec<u8>>
174}
175
176impl Socket {
177    pub fn new(addr: SocketAddr) -> Result<Socket, Error> {
178        match UdpSocket::bind(addr) {
179            Ok(socket) => Ok(Socket {
180                addr,
181                socket,
182                secrets: HashMap::new(),
183                floating_connections: HashMap::new(),
184                encoder: Box::new(xor_encode),
185                decoder: Box::new(xor_encode)
186            }),
187            Err(err) => Err(err)
188        }
189    }
190
191    pub fn from_socket(socket: UdpSocket) -> Result<Socket, Error> {
192        match socket.local_addr() {
193            Ok(addr) => Ok(Socket {
194                socket,
195                addr,
196                secrets: HashMap::new(),
197                floating_connections: HashMap::new(),
198                encoder: Box::new(xor_encode),
199                decoder: Box::new(xor_encode)
200            }),
201            Err(err) => Err(err)
202        }
203    }
204
205    pub fn socket(&self) -> &UdpSocket {
206        &self.socket
207    }
208
209    pub fn addr(&self) -> SocketAddr {
210        self.addr
211    }
212
213    /// Sets function that will encode datagrams before its transferring
214    pub fn set_encoder<T: Fn(Vec<u8>, &[u8; 32]) -> Vec<u8> + 'static>(&mut self, encoder: T) {
215        self.encoder = Box::new(encoder);
216    }
217
218    /// Sets function that will decode received datagrams
219    pub fn set_decoder<T: Fn(Vec<u8>, &[u8; 32]) -> Vec<u8> + 'static>(&mut self, decoder: T) {
220        self.decoder = Box::new(decoder);
221    }
222
223    fn write(&self, addr: SocketAddr, packet: Packet) -> Result<usize, Error> {
224        self.socket.send_to(packet.to_bytes().as_slice(), addr)
225    }
226
227    fn read(&self) -> Result<(SocketAddr, Packet), Error> {
228        let mut buf = [0; 65536];
229
230        match self.socket.recv_from(&mut buf) {
231            Ok((size, from)) => {
232                match Packet::from_bytes(&buf[..size]) {
233                    Ok(packet) => Ok((from, packet)),
234                    Err(err) => Err(err)
235                }
236            },
237            Err(err) => Err(err)
238        }
239    }
240
241    /// Generate shared secret with specified remote address
242    /// 
243    /// ```
244    /// use udpsec::Socket;
245    /// 
246    /// let local_addr = "127.0.0.1:50000".parse().unwrap();
247    /// let remote_addr = "127.0.0.1:50001".parse().unwrap();
248    /// 
249    /// let mut socket_a = Socket::new(local_addr).unwrap();
250    /// let mut socket_b = Socket::new(remote_addr).unwrap();
251    /// 
252    /// // wait_not_await's Await
253    /// socket_a.generate_secret(remote_addr).unwrap().then(|ping| {
254    ///     println!("Ping: {} ms", ping.as_millis());
255    /// });
256    /// 
257    /// socket_b.recv(); // Remote client updates its state in a loop
258    /// 
259    /// while socket_a.shared_secret(remote_addr) == None {
260    ///     socket_a.recv();
261    /// }
262    /// 
263    /// println!("Shared secret (local): {:?}", socket_a.shared_secret(remote_addr).unwrap());
264    /// println!("Shared secret (remote): {:?}", socket_b.shared_secret(local_addr).unwrap());
265    /// ```
266    pub fn generate_secret(&mut self, addr: SocketAddr) -> Result<Await<Duration>, Error> {
267        let (sender, receiver) = mpsc::channel::<()>();
268
269        self.floating_connections.insert(addr, (ReusableSecret::new(OsRng), sender));
270
271        match self.write(addr, Packet::KeyExchangeInit(PublicKey::from(&self.floating_connections.get(&addr).unwrap().0))) {
272            Ok(_) => {
273                let instant = Instant::now();
274
275                Ok(Await::new(move || {
276                    receiver.recv();
277
278                    instant.elapsed()
279                }))
280            },
281            Err(err) => Err(err)
282        }
283    }
284
285    /// Get shared secret with a specified remote address
286    /// 
287    /// See `Socket.generate_secret()` for more details
288    pub fn shared_secret(&self, addr: SocketAddr) -> Option<&[u8; 32]> {
289        self.secrets.get(&addr)
290    }
291
292    /// Send data to remote address
293    /// 
294    /// Returns `false` if data couldn't be sent, or shared secret wasn't generated
295    /// 
296    /// ```
297    /// use udpsec::Socket;
298    /// 
299    /// let local_addr = "127.0.0.1:50002".parse().unwrap();
300    /// let remote_addr = "127.0.0.1:50003".parse().unwrap();
301    /// 
302    /// let mut socket_a = Socket::new(local_addr).unwrap();
303    /// let mut socket_b = Socket::new(remote_addr).unwrap();
304    /// 
305    /// socket_a.generate_secret(remote_addr); // Send KeyExchangeInit to remote client
306    /// 
307    /// socket_b.recv(); // Process KeyExchangeInit packet from local client
308    /// socket_a.recv(); // Receive KeyExchangeDone packet from remote client
309    /// 
310    /// socket_a.send(remote_addr, "Hello, World!".as_bytes().to_vec());
311    /// 
312    /// let received = socket_b.recv().unwrap();
313    /// 
314    /// println!("[{}] {}", received.0, String::from_utf8(received.1).unwrap());
315    /// ```
316    pub fn send(&self, addr: SocketAddr, mut data: Vec<u8>) -> Result<usize, Error> {
317        match self.secrets.get(&addr) {
318            Some(secret) => {
319                data = [get_checksum(data.as_slice()).to_vec(), data].concat();
320                data = (self.encoder)(data, secret);
321                
322                self.write(addr, Packet::Datagram(data))
323            },
324            None => Err(Error::new(ErrorKind::NotConnected, "Current socket doesn't have a shared secret with specified remote address"))
325        }
326    }
327
328    /// Receive data from remote socket
329    /// 
330    /// See `Socket.send()` for more details
331    pub fn recv(&mut self) -> Option<(SocketAddr, Vec<u8>)> {
332        match self.read() {
333            Ok((from, packet)) => {
334                match packet {
335                    Packet::KeyExchangeInit(public_key) => {
336                        let secret = ReusableSecret::new(OsRng);
337
338                        self.secrets.insert(from, *secret.diffie_hellman(&public_key).as_bytes());
339
340                        self.write(from, Packet::KeyExchangeDone(PublicKey::from(&secret)));
341
342                        None
343                    }
344
345                    Packet::KeyExchangeDone(public_key) => {
346                        if let Some((secret, sender)) = self.floating_connections.get(&from) {
347                            self.secrets.insert(from, *secret.diffie_hellman(&public_key).as_bytes());
348
349                            sender.send(());
350                            
351                            self.floating_connections.remove(&from);
352                        }
353
354                        None
355                    }
356
357                    Packet::Datagram(data) => {
358                        // If we have secret key - try to decode data
359                        if let Some(secret) = self.secrets.get(&from) {
360                            let decoded = (self.decoder)(data.clone(), secret);
361
362                            if get_checksum(&decoded[4..]) == &decoded[0..4] {
363                                return Some((from, decoded[4..].to_vec()))
364                            }
365                        }
366
367                        // Otherwise let's check all the others secrets and mb there'll be what we need
368                        // The problem is that if client will change its IP address for some reason - it'll
369                        // use already generated shared secret while we'll not know it. So we need to check
370                        // all the connected clients and move secret from one to another if it'll decode data correctly
371                        let mut found = None;
372
373                        for (remote, shared) in &self.secrets {
374                            let decoded = (self.decoder)(data.clone(), &shared);
375
376                            if get_checksum(&decoded[4..]) == &decoded[0..4] {
377                                found = Some((decoded[4..].to_vec(), remote.clone(), shared.clone()));
378
379                                break;
380                            }
381                        }
382
383                        if let Some((decoded, remote, shared)) = found {
384                            self.secrets.remove(&remote);
385                            self.secrets.insert(from, shared.clone());
386
387                            return Some((from, decoded));
388                        }
389
390                        None
391                    }
392                }
393            },
394            Err(_) => None
395        }
396    }
397}