1use std::collections::HashMap;
20use std::hash::Hash;
21use std::net::SocketAddr;
22use std::time::Instant;
23
24use bytes::Bytes;
25
26use crate::packet_utils::read_op_code;
27use crate::packets::RemapConnection;
28use crate::protocol::{DisconnectReason, OpCode};
29use crate::session::{
30 ApplicationParameters, SessionEvent, SessionMode, SessionParameters, SessionState, SoeSession,
31};
32
33pub trait SoeSocket {
43 fn local_addr(&self) -> std::io::Result<SocketAddr>;
45
46 fn session_count(&self) -> usize;
48
49 fn connect(&mut self, remote: SocketAddr);
52
53 fn enqueue_data(&mut self, remote: &SocketAddr, data: &[u8]) -> bool;
56
57 fn terminate(&mut self, remote: &SocketAddr, reason: DisconnectReason);
59}
60
/// Key type that identifies a remote peer of a `SoeMultiplexer`.
///
/// Implementors must be hashable map keys (`Clone + Eq + Hash`). They must
/// also be able to tell whether two addresses belong to the same host.
/// Port remaps are only honoured between addresses on the same host.
pub trait RemoteAddr: Clone + Eq + Hash {
    /// Returns `true` when `self` and `other` identify the same host, ignoring
    /// any per-connection component such as the port.
    fn same_host(&self, other: &Self) -> bool;
}

impl RemoteAddr for SocketAddr {
    /// Same host means identical IP address; the port is ignored.
    ///
    /// The comparison is on `IpAddr` as-is, so an IPv4 address and its
    /// IPv4-mapped IPv6 form are treated as different hosts.
    fn same_host(&self, other: &Self) -> bool {
        let (ours, theirs) = (self.ip(), other.ip());
        ours == theirs
    }
}
77
78#[derive(Debug, Clone, Default)]
80pub struct SocketConfig {
81 pub default_session_params: SessionParameters,
83 pub app_params: ApplicationParameters,
85 pub allow_port_remaps: bool,
87 pub base_rng_seed: u64,
90}
91
92#[derive(Debug, Clone, PartialEq, Eq)]
94pub enum SocketEvent<A> {
95 SessionOpened {
97 remote: A,
99 },
100 DataReceived {
102 remote: A,
104 data: Bytes,
106 },
107 SessionClosed {
109 remote: A,
111 reason: DisconnectReason,
113 },
114}
115
116#[derive(Debug)]
118pub struct SoeMultiplexer<A: RemoteAddr> {
119 config: SocketConfig,
120 sessions: HashMap<A, SoeSession>,
121 outgoing: Vec<(A, Bytes)>,
122 events: Vec<SocketEvent<A>>,
123 next_seed: u64,
124}
125
126impl<A: RemoteAddr> SoeMultiplexer<A> {
127 pub fn new(config: SocketConfig) -> Self {
129 let next_seed = config.base_rng_seed;
130 Self {
131 config,
132 sessions: HashMap::new(),
133 outgoing: Vec::new(),
134 events: Vec::new(),
135 next_seed,
136 }
137 }
138
139 pub fn session_count(&self) -> usize {
141 self.sessions.len()
142 }
143
144 pub fn session(&self, remote: &A) -> Option<&SoeSession> {
146 self.sessions.get(remote)
147 }
148
149 pub fn take_outgoing(&mut self) -> Vec<(A, Bytes)> {
151 std::mem::take(&mut self.outgoing)
152 }
153
154 pub fn take_events(&mut self) -> Vec<SocketEvent<A>> {
156 std::mem::take(&mut self.events)
157 }
158
159 pub fn connect(&mut self, remote: A, now: Instant) {
161 self.create_session(remote.clone(), SessionMode::Client, now);
162 if let Some(session) = self.sessions.get_mut(&remote) {
163 session.send_session_request();
164 }
165 self.drain_session(&remote);
166 }
167
168 #[must_use = "a false return means the data was dropped because no running session exists for the address"]
171 pub fn enqueue_data(&mut self, remote: &A, data: &[u8]) -> bool {
172 let queued = match self.sessions.get_mut(remote) {
173 Some(session) => session.enqueue_data(data),
174 None => false,
175 };
176 self.drain_session(remote);
177 queued
178 }
179
180 pub fn terminate(&mut self, remote: &A, reason: DisconnectReason, now: Instant) {
182 if let Some(session) = self.sessions.get_mut(remote) {
183 session.terminate(reason, true, now);
184 }
185 self.drain_session(remote);
186 self.remove_if_terminated(remote);
187 }
188
189 pub fn process_incoming(&mut self, remote: A, datagram: Bytes, now: Instant) {
191 if !self.sessions.contains_key(&remote) {
192 match read_op_code(&datagram) {
193 Some(OpCode::SessionRequest) => {
194 self.create_session(remote.clone(), SessionMode::Server, now);
195 }
196 Some(OpCode::RemapConnection) => {
197 self.handle_remap(&remote, &datagram);
198 return;
199 }
200 _ => return,
202 }
203 }
204
205 if let Some(session) = self.sessions.get_mut(&remote) {
206 session.process_incoming(datagram, now);
207 }
208 self.drain_session(&remote);
209 self.remove_if_terminated(&remote);
210 }
211
212 pub fn run_tick(&mut self, now: Instant) {
216 let mut outgoing = std::mem::take(&mut self.outgoing);
221 let mut events = std::mem::take(&mut self.events);
222
223 self.sessions.retain(|remote, session| {
224 session.run_tick(now);
225 Self::drain_into(remote, session, &mut outgoing, &mut events);
226 session.state() != SessionState::Terminated
227 });
228
229 self.outgoing = outgoing;
230 self.events = events;
231 }
232
233 pub fn drive<T>(&mut self, transport: &mut T, now: Instant) -> std::io::Result<()>
239 where
240 T: UdpTransport<Addr = A>,
241 {
242 let mut buf = [0u8; 2048];
243 while let Some((len, from)) = transport.try_recv(&mut buf)? {
244 self.process_incoming(from, Bytes::copy_from_slice(&buf[..len]), now);
245 }
246
247 self.run_tick(now);
248
249 for (addr, datagram) in self.take_outgoing() {
250 transport.send_to(&datagram, &addr)?;
251 }
252 Ok(())
253 }
254
255 fn create_session(&mut self, remote: A, mode: SessionMode, now: Instant) {
256 let seed = self.next_seed;
257 self.next_seed = self.next_seed.wrapping_add(1);
258
259 let session = SoeSession::new(
260 mode,
261 self.config.default_session_params.clone(),
262 self.config.app_params.clone(),
263 seed,
264 now,
265 );
266 self.sessions.insert(remote, session);
267 }
268
269 fn handle_remap(&mut self, from: &A, datagram: &[u8]) {
270 if !self.config.allow_port_remaps {
271 return;
272 }
273
274 let remap = match RemapConnection::deserialize(datagram, true) {
275 Ok(remap) => remap,
276 Err(_) => return,
277 };
278
279 let old_key = self.sessions.iter().find_map(|(key, session)| {
280 (session.session_id() == remap.session_id && session.crc_seed() == remap.crc_seed)
281 .then(|| key.clone())
282 });
283 let Some(old_key) = old_key else { return };
284
285 if &old_key == from || !old_key.same_host(from) {
287 return;
288 }
289
290 if let Some(session) = self.sessions.remove(&old_key) {
291 self.sessions.insert(from.clone(), session);
292 }
293 }
294
295 fn drain_session(&mut self, remote: &A) {
296 if let Some(session) = self.sessions.get_mut(remote) {
297 Self::drain_into(remote, session, &mut self.outgoing, &mut self.events);
298 }
299 }
300
301 fn drain_into(
310 remote: &A,
311 session: &mut SoeSession,
312 outgoing: &mut Vec<(A, Bytes)>,
313 events: &mut Vec<SocketEvent<A>>,
314 ) {
315 for datagram in session.take_outgoing() {
316 outgoing.push((remote.clone(), datagram));
317 }
318
319 let session_events = session.take_events();
320
321 for event in &session_events {
322 if matches!(event, SessionEvent::Opened) {
323 events.push(SocketEvent::SessionOpened {
324 remote: remote.clone(),
325 });
326 }
327 }
328
329 for data in session.take_received() {
330 events.push(SocketEvent::DataReceived {
331 remote: remote.clone(),
332 data,
333 });
334 }
335
336 for event in session_events {
337 if let SessionEvent::Closed(reason) = event {
338 events.push(SocketEvent::SessionClosed {
339 remote: remote.clone(),
340 reason,
341 });
342 }
343 }
344 }
345
346 fn remove_if_terminated(&mut self, remote: &A) {
347 if let Some(session) = self.sessions.get(remote)
348 && session.state() == SessionState::Terminated
349 {
350 self.sessions.remove(remote);
351 }
352 }
353}
354
355pub trait UdpTransport {
360 type Addr: RemoteAddr;
362
363 fn try_recv(&mut self, buf: &mut [u8]) -> std::io::Result<Option<(usize, Self::Addr)>>;
366
367 fn send_to(&mut self, buf: &[u8], addr: &Self::Addr) -> std::io::Result<usize>;
369}
370
371impl UdpTransport for std::net::UdpSocket {
372 type Addr = std::net::SocketAddr;
373
374 fn try_recv(&mut self, buf: &mut [u8]) -> std::io::Result<Option<(usize, Self::Addr)>> {
375 match self.recv_from(buf) {
376 Ok((len, from)) => Ok(Some((len, from))),
377 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(None),
378 Err(e) => Err(e),
379 }
380 }
381
382 fn send_to(&mut self, buf: &[u8], addr: &Self::Addr) -> std::io::Result<usize> {
383 std::net::UdpSocket::send_to(self, buf, addr)
384 }
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rc4::Rc4KeyState;
    use std::net::SocketAddr;

    const CLIENT: &str = "127.0.0.1:4001";
    const SERVER: &str = "127.0.0.1:4002";

    /// Parses a literal socket address; panics on malformed test input.
    fn addr(s: &str) -> SocketAddr {
        s.parse().unwrap()
    }

    /// Builds a config for `protocol` whose sessions are seeded from `seed`.
    ///
    /// `heartbeat_after` and `inactivity_timeout` are set to zero, presumably
    /// to keep session timers from interfering with these tests. TODO: confirm
    /// what a zero duration means to `SoeSession`.
    fn config(protocol: &str, seed: u64) -> SocketConfig {
        let params = SessionParameters {
            application_protocol: protocol.to_owned(),
            heartbeat_after: std::time::Duration::ZERO,
            inactivity_timeout: std::time::Duration::ZERO,
            ..SessionParameters::default()
        };
        SocketConfig {
            default_session_params: params,
            app_params: ApplicationParameters::default(),
            allow_port_remaps: false,
            base_rng_seed: seed,
        }
    }

    /// Shuttles datagrams between `client` and `server` until neither has
    /// anything to send, for at most 64 rounds.
    ///
    /// Destinations are ignored. Everything the client emits reaches the
    /// server as coming from `CLIENT`, and everything the server emits reaches
    /// the client as coming from `SERVER`.
    fn pump(client: &mut SoeMultiplexer<SocketAddr>, server: &mut SoeMultiplexer<SocketAddr>) {
        let now = Instant::now();
        for _ in 0..64 {
            client.run_tick(now);
            server.run_tick(now);

            let from_client = client.take_outgoing();
            let from_server = server.take_outgoing();
            if from_client.is_empty() && from_server.is_empty() {
                break;
            }
            for (_dest, dg) in from_client {
                server.process_incoming(addr(CLIENT), dg, now);
            }
            for (_dest, dg) in from_server {
                client.process_incoming(addr(SERVER), dg, now);
            }
        }
    }

    /// Connects a fresh client (at `CLIENT`) to a fresh server (at `SERVER`)
    /// and pumps the handshake. `allow_port_remaps` applies to the server only.
    fn connected_pair(
        allow_port_remaps: bool,
    ) -> (SoeMultiplexer<SocketAddr>, SoeMultiplexer<SocketAddr>) {
        let mut client = SoeMultiplexer::new(config("TestProtocol", 1));
        let mut server_cfg = config("TestProtocol", 2);
        server_cfg.allow_port_remaps = allow_port_remaps;
        let mut server = SoeMultiplexer::new(server_cfg);

        client.connect(addr(SERVER), Instant::now());
        pump(&mut client, &mut server);
        (client, server)
    }

    /// Serializes a `RemapConnection` whose id and CRC seed match the
    /// server's session for `CLIENT`.
    fn remap_request_for_client(server: &SoeMultiplexer<SocketAddr>) -> Bytes {
        let session = server.session(&addr(CLIENT)).expect("server session");
        let remap = RemapConnection {
            session_id: session.session_id(),
            crc_seed: session.crc_seed(),
        };
        let mut buf = [0u8; RemapConnection::SIZE];
        let n = remap.serialize(&mut buf).unwrap();
        Bytes::copy_from_slice(&buf[..n])
    }

    #[test]
    fn establishes_session_and_emits_opened() {
        let now = Instant::now();
        let mut client = SoeMultiplexer::new(config("TestProtocol", 1));
        let mut server = SoeMultiplexer::new(config("TestProtocol", 2));

        client.connect(addr(SERVER), now);
        pump(&mut client, &mut server);

        assert_eq!(client.session_count(), 1);
        assert_eq!(server.session_count(), 1);
        assert!(client.take_events().iter().any(|e| matches!(
            e,
            SocketEvent::SessionOpened { remote } if *remote == addr(SERVER)
        )));

        // The server-side `SessionOpened` is checked only after application
        // data has flowed from the client.
        assert!(client.enqueue_data(&addr(SERVER), b"hi"));
        pump(&mut client, &mut server);
        assert!(server.take_events().iter().any(|e| matches!(
            e,
            SocketEvent::SessionOpened { remote } if *remote == addr(CLIENT)
        )));
    }

    #[test]
    fn routes_data_between_peers() {
        let mut client = SoeMultiplexer::new(config("TestProtocol", 1));
        let mut server = SoeMultiplexer::new(config("TestProtocol", 2));

        client.connect(addr(SERVER), Instant::now());
        pump(&mut client, &mut server);

        assert!(client.enqueue_data(&addr(SERVER), b"ping"));
        pump(&mut client, &mut server);
        assert!(server.take_events().iter().any(|e| matches!(
            e,
            SocketEvent::DataReceived { remote, data } if *remote == addr(CLIENT) && data == "ping"
        )));

        assert!(server.enqueue_data(&addr(CLIENT), b"pong"));
        pump(&mut client, &mut server);
        assert!(client.take_events().iter().any(|e| matches!(
            e,
            SocketEvent::DataReceived { remote, data } if *remote == addr(SERVER) && data == "pong"
        )));
    }

    #[test]
    fn encrypted_data_routes_between_peers() {
        let key = Rc4KeyState::new(&[1, 2, 3, 4, 5]);
        let mut client_cfg = config("TestProtocol", 1);
        let mut server_cfg = config("TestProtocol", 2);
        client_cfg.app_params.encryption_key_state = Some(key.clone());
        server_cfg.app_params.encryption_key_state = Some(key);

        let mut client = SoeMultiplexer::new(client_cfg);
        let mut server = SoeMultiplexer::new(server_cfg);

        client.connect(addr(SERVER), Instant::now());
        pump(&mut client, &mut server);

        let payload = vec![0u8; 200];
        assert!(client.enqueue_data(&addr(SERVER), &payload));
        pump(&mut client, &mut server);
        assert!(server.take_events().iter().any(|e| matches!(
            e,
            SocketEvent::DataReceived { remote, data }
                if *remote == addr(CLIENT) && data.as_ref() == payload.as_slice()
        )));
    }

    #[test]
    fn terminate_notifies_remote_and_removes_session() {
        let now = Instant::now();
        let mut client = SoeMultiplexer::new(config("TestProtocol", 1));
        let mut server = SoeMultiplexer::new(config("TestProtocol", 2));

        client.connect(addr(SERVER), now);
        pump(&mut client, &mut server);
        // Discard handshake events so only the close is inspected below.
        client.take_events();
        server.take_events();

        client.terminate(&addr(SERVER), DisconnectReason::Application, now);
        pump(&mut client, &mut server);

        assert_eq!(client.session_count(), 0);
        assert_eq!(server.session_count(), 0);
        assert!(server.take_events().iter().any(|e| matches!(
            e,
            SocketEvent::SessionClosed { remote, reason }
                if *remote == addr(CLIENT) && *reason == DisconnectReason::Application
        )));
    }

    #[test]
    fn ignores_stray_datagram_without_session() {
        let now = Instant::now();
        let mut server = SoeMultiplexer::<SocketAddr>::new(config("TestProtocol", 1));

        // Op code 0x0009 is neither SessionRequest nor RemapConnection
        // (presumably a data packet; verify against `OpCode`), so an unknown
        // sender must be ignored.
        server.process_incoming(addr(CLIENT), Bytes::from_static(&[0x00, 0x09, 0x00]), now);

        assert_eq!(server.session_count(), 0);
        assert!(server.take_outgoing().is_empty());
        assert!(server.take_events().is_empty());
    }

    #[test]
    fn enqueue_without_session_is_rejected() {
        let mut client = SoeMultiplexer::<SocketAddr>::new(config("TestProtocol", 1));

        assert!(!client.enqueue_data(&addr(SERVER), b"orphan"));
        assert!(client.take_outgoing().is_empty());
        assert!(client.take_events().is_empty());
    }

    #[test]
    fn remaps_port_for_matching_session() {
        let (_client, mut server) = connected_pair(true);
        let request = remap_request_for_client(&server);

        let new_client = addr("127.0.0.1:4099");
        server.process_incoming(new_client, request, Instant::now());

        assert!(server.session(&addr(CLIENT)).is_none());
        assert!(server.session(&new_client).is_some());
    }

    #[test]
    fn ignores_remap_when_disabled() {
        let (_client, mut server) = connected_pair(false);
        let request = remap_request_for_client(&server);

        let new_client = addr("127.0.0.1:4099");
        server.process_incoming(new_client, request, Instant::now());

        assert!(server.session(&addr(CLIENT)).is_some());
        assert!(server.session(&new_client).is_none());
    }

    #[test]
    fn rejects_remap_from_different_host() {
        let (_client, mut server) = connected_pair(true);
        let request = remap_request_for_client(&server);

        // Different IP from CLIENT, so `same_host` fails even though the
        // session id and CRC seed match.
        let attacker = addr("10.0.0.1:5000");
        server.process_incoming(attacker, request, Instant::now());

        assert!(server.session(&addr(CLIENT)).is_some());
        assert!(server.session(&attacker).is_none());
    }
}