1use std::collections::HashMap;
2use std::net::{UdpSocket, SocketAddr};
3use std::io::{ErrorKind, Error};
4use std::time::{Duration, Instant};
5use std::sync::mpsc;
6
7use rand_core::{OsRng, RngCore};
8use x25519_dalek::{PublicKey, ReusableSecret};
9use wait_not_await::Await;
10
11pub use rand_core;
12pub use x25519_dalek;
13pub use wait_not_await;
14
15mod tests;
16
/// Number of checksum bytes that `Socket::send` prepends to each payload before encoding.
pub const CHECKSUM_LENGTH: usize = 4;

/// Largest application payload intended to fit in a single datagram once the
/// checksum has been reserved.
///
/// NOTE(review): 65 506 looks like the 65 507-byte IPv4 UDP payload limit minus the
/// one-byte packet tag written by `Packet::to_bytes` — confirm. This file does not
/// enforce the limit itself.
pub const DATAGRAM_MAX_LENGTH: usize = 65_506 - CHECKSUM_LENGTH;
24
25fn rand_u32() -> u32 {
26 OsRng::default().next_u32()
27}
28
29fn rand_u8() -> u8 {
30 (rand_u32() % 256) as u8
31}
32
33#[derive(Debug, Clone, PartialEq, Eq)]
34enum Packet {
35 KeyExchangeInit(PublicKey),
36 KeyExchangeDone(PublicKey),
37 Datagram(Vec<u8>)
38}
39
40impl Packet {
41 pub fn to_bytes(&self) -> Vec<u8> {
42 match self {
43 Packet::KeyExchangeInit(public_key) => {
45 let mut packet = Vec::with_capacity(33);
46
47 packet.push(rand_u8() % 86);
48 packet.append(&mut public_key.as_bytes().to_vec());
49
50 for _ in 0..rand_u8() {
52 let mut rand = rand_u32();
53
54 while rand > 0 {
55 packet.push((rand % 256) as u8);
56
57 rand /= 256;
58 }
59 }
60
61 packet
62 },
63
64 Packet::KeyExchangeDone(public_key) => {
66 let mut packet = Vec::with_capacity(33);
67
68 packet.push(rand_u8() % 85 + 86);
69 packet.append(&mut public_key.as_bytes().to_vec());
70
71 for _ in 0..rand_u8() {
73 let mut rand = rand_u32();
74
75 while rand > 0 {
76 packet.push((rand % 256) as u8);
77
78 rand /= 256;
79 }
80 }
81
82 packet
83 },
84
85 Packet::Datagram(data) => {
88 [vec![rand_u8() % 85 + 171], data.clone()].concat()
89 }
90 }
91 }
92
93 pub fn from_bytes(bytes: &[u8]) -> Result<Packet, Error> {
94 if bytes.len() > 0 {
95 if bytes[0] < 86 {
96 let mut public_key = [0u8; 32];
97
98 public_key.copy_from_slice(&bytes[1..33]);
99
100 match PublicKey::try_from(public_key) {
101 Ok(public_key) => Ok(Packet::KeyExchangeInit(public_key)),
102 Err(_) => Err(Error::new(ErrorKind::InvalidData, "Public key decoding error"))
103 }
104 }
105
106 else if bytes[0] < 171 {
107 let mut public_key = [0u8; 32];
108
109 public_key.copy_from_slice(&bytes[1..33]);
110
111 match PublicKey::try_from(public_key) {
112 Ok(public_key) => Ok(Packet::KeyExchangeDone(public_key)),
113 Err(_) => Err(Error::new(ErrorKind::InvalidData, "Public key decoding error"))
114 }
115 }
116
117 else {
118 Ok(Packet::Datagram(bytes[1..].to_vec()))
119 }
120 }
121
122 else {
123 Err(Error::new(ErrorKind::InvalidInput, "Slice is empty"))
124 }
125 }
126}
127
/// XORs every byte of `data` with `key`, repeating the key every 32 bytes.
///
/// XOR is its own inverse, so calling this twice with the same key restores the
/// input; that is why `Socket` uses it as both the default encoder and decoder.
/// Because the key simply repeats, this is obfuscation rather than encryption.
pub fn xor_encode(mut data: Vec<u8>, key: &[u8; 32]) -> Vec<u8> {
    for (byte, key_byte) in data.iter_mut().zip(key.iter().cycle()) {
        *byte ^= *key_byte;
    }

    data
}
138
/// Identity codec: hands `data` back untouched and never looks at the key.
///
/// Its signature matches `Socket::set_encoder` / `Socket::set_decoder`, so it can
/// be installed there to send payloads without the default XOR step.
pub fn plain_text(data: Vec<u8>, _key: &[u8; 32]) -> Vec<u8> {
    data
}
145
146fn get_checksum(data: &[u8]) -> [u8; CHECKSUM_LENGTH] {
148 let mut checksum = [0; CHECKSUM_LENGTH];
149
150 for i in 0..data.len() {
151 let mut sum = u16::from(checksum[i % CHECKSUM_LENGTH]) + u16::from(data[i]);
152
153 if sum > 255 {
154 sum %= 256;
155 }
156
157 checksum[i % CHECKSUM_LENGTH] = sum as u8;
158 }
159
160 checksum
161}
162
163pub struct Socket {
164 addr: SocketAddr,
165 socket: UdpSocket,
166 secrets: HashMap<SocketAddr, [u8; 32]>,
167 floating_connections: HashMap<SocketAddr, (ReusableSecret, mpsc::Sender<()>)>,
168
169 pub encoder: Box<dyn Fn(Vec<u8>, &[u8; 32]) -> Vec<u8>>,
171
172 pub decoder: Box<dyn Fn(Vec<u8>, &[u8; 32]) -> Vec<u8>>
174}
175
176impl Socket {
177 pub fn new(addr: SocketAddr) -> Result<Socket, Error> {
178 match UdpSocket::bind(addr) {
179 Ok(socket) => Ok(Socket {
180 addr,
181 socket,
182 secrets: HashMap::new(),
183 floating_connections: HashMap::new(),
184 encoder: Box::new(xor_encode),
185 decoder: Box::new(xor_encode)
186 }),
187 Err(err) => Err(err)
188 }
189 }
190
191 pub fn from_socket(socket: UdpSocket) -> Result<Socket, Error> {
192 match socket.local_addr() {
193 Ok(addr) => Ok(Socket {
194 socket,
195 addr,
196 secrets: HashMap::new(),
197 floating_connections: HashMap::new(),
198 encoder: Box::new(xor_encode),
199 decoder: Box::new(xor_encode)
200 }),
201 Err(err) => Err(err)
202 }
203 }
204
205 pub fn socket(&self) -> &UdpSocket {
206 &self.socket
207 }
208
209 pub fn addr(&self) -> SocketAddr {
210 self.addr
211 }
212
213 pub fn set_encoder<T: Fn(Vec<u8>, &[u8; 32]) -> Vec<u8> + 'static>(&mut self, encoder: T) {
215 self.encoder = Box::new(encoder);
216 }
217
218 pub fn set_decoder<T: Fn(Vec<u8>, &[u8; 32]) -> Vec<u8> + 'static>(&mut self, decoder: T) {
220 self.decoder = Box::new(decoder);
221 }
222
223 fn write(&self, addr: SocketAddr, packet: Packet) -> Result<usize, Error> {
224 self.socket.send_to(packet.to_bytes().as_slice(), addr)
225 }
226
227 fn read(&self) -> Result<(SocketAddr, Packet), Error> {
228 let mut buf = [0; 65536];
229
230 match self.socket.recv_from(&mut buf) {
231 Ok((size, from)) => {
232 match Packet::from_bytes(&buf[..size]) {
233 Ok(packet) => Ok((from, packet)),
234 Err(err) => Err(err)
235 }
236 },
237 Err(err) => Err(err)
238 }
239 }
240
241 pub fn generate_secret(&mut self, addr: SocketAddr) -> Result<Await<Duration>, Error> {
267 let (sender, receiver) = mpsc::channel::<()>();
268
269 self.floating_connections.insert(addr, (ReusableSecret::new(OsRng), sender));
270
271 match self.write(addr, Packet::KeyExchangeInit(PublicKey::from(&self.floating_connections.get(&addr).unwrap().0))) {
272 Ok(_) => {
273 let instant = Instant::now();
274
275 Ok(Await::new(move || {
276 receiver.recv();
277
278 instant.elapsed()
279 }))
280 },
281 Err(err) => Err(err)
282 }
283 }
284
285 pub fn shared_secret(&self, addr: SocketAddr) -> Option<&[u8; 32]> {
289 self.secrets.get(&addr)
290 }
291
292 pub fn send(&self, addr: SocketAddr, mut data: Vec<u8>) -> Result<usize, Error> {
317 match self.secrets.get(&addr) {
318 Some(secret) => {
319 data = [get_checksum(data.as_slice()).to_vec(), data].concat();
320 data = (self.encoder)(data, secret);
321
322 self.write(addr, Packet::Datagram(data))
323 },
324 None => Err(Error::new(ErrorKind::NotConnected, "Current socket doesn't have a shared secret with specified remote address"))
325 }
326 }
327
328 pub fn recv(&mut self) -> Option<(SocketAddr, Vec<u8>)> {
332 match self.read() {
333 Ok((from, packet)) => {
334 match packet {
335 Packet::KeyExchangeInit(public_key) => {
336 let secret = ReusableSecret::new(OsRng);
337
338 self.secrets.insert(from, *secret.diffie_hellman(&public_key).as_bytes());
339
340 self.write(from, Packet::KeyExchangeDone(PublicKey::from(&secret)));
341
342 None
343 }
344
345 Packet::KeyExchangeDone(public_key) => {
346 if let Some((secret, sender)) = self.floating_connections.get(&from) {
347 self.secrets.insert(from, *secret.diffie_hellman(&public_key).as_bytes());
348
349 sender.send(());
350
351 self.floating_connections.remove(&from);
352 }
353
354 None
355 }
356
357 Packet::Datagram(data) => {
358 if let Some(secret) = self.secrets.get(&from) {
360 let decoded = (self.decoder)(data.clone(), secret);
361
362 if get_checksum(&decoded[4..]) == &decoded[0..4] {
363 return Some((from, decoded[4..].to_vec()))
364 }
365 }
366
367 let mut found = None;
372
373 for (remote, shared) in &self.secrets {
374 let decoded = (self.decoder)(data.clone(), &shared);
375
376 if get_checksum(&decoded[4..]) == &decoded[0..4] {
377 found = Some((decoded[4..].to_vec(), remote.clone(), shared.clone()));
378
379 break;
380 }
381 }
382
383 if let Some((decoded, remote, shared)) = found {
384 self.secrets.remove(&remote);
385 self.secrets.insert(from, shared.clone());
386
387 return Some((from, decoded));
388 }
389
390 None
391 }
392 }
393 },
394 Err(_) => None
395 }
396 }
397}