turn_server/turn/sessions.rs
1use super::Observer;
2use crate::stun::util::long_term_credential_digest;
3
4use std::{
5 hash::Hash,
6 net::SocketAddr,
7 ops::{Deref, DerefMut, Range},
8 sync::{
9 Arc,
10 atomic::{AtomicU64, Ordering},
11 },
12 thread::{self, sleep},
13 time::Duration,
14};
15
16use ahash::{HashMap, HashMapExt};
17use parking_lot::{Mutex, RwLock, RwLockReadGuard};
18use rand::{Rng, distributions::Alphanumeric, thread_rng};
19
/// Credentials cached for an authenticated session.
///
/// `integrity` holds the 16-byte long-term credential digest computed from the
/// username, password and realm when the session was created, so later
/// `get_integrity` calls can return it without asking the observer again.
#[derive(Debug, Clone)]
pub struct Auth {
    /// Username the session authenticated with.
    pub username: String,
    /// Long-term credential digest derived for `username`.
    pub integrity: [u8; 16],
}
29
/// Relay allocation state of a session.
///
/// A session owns at most one relay port, which stays `None` until
/// `Sessions::allocate` succeeds. It may bind any number of channel numbers;
/// `Sessions::bind_channel` never pushes the same channel twice.
#[derive(Debug, Clone)]
pub struct Allocate {
    /// Relay port taken from the port pool, if one has been allocated.
    pub port: Option<u16>,
    /// Channel numbers bound by this session.
    pub channels: Vec<u16>,
}
38
39/// turn session information.
40///
41/// A user can have many sessions.
42///
43/// The default survival time for a session is 600 seconds.
44#[derive(Debug, Clone)]
45pub struct Session {
46 pub auth: Auth,
47 pub allocate: Allocate,
48 pub permissions: Vec<u16>,
49 pub expires: u64,
50}
51
/// Key that identifies a session (and its nonce) in the session tables.
///
/// A session is identified by two socket addresses:
/// - `address`: the client side of the session.
/// - `interface`: the local address the client reached.
///
/// The transport protocol is not part of the key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SessionAddr {
    /// Socket address of the client side of the session.
    pub address: SocketAddr,
    /// Local socket address the traffic arrived on. Presumably the server's
    /// listening interface; confirm against callers.
    pub interface: SocketAddr,
}
61
/// Forwarding target stored in the relay tables and used when forwarding data.
///
/// `create_permission` and `bind_channel` store one of these under a peer
/// session. It pairs the client `address` of the session that created the
/// binding with the `endpoint` socket address passed by the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Endpoint {
    /// Client address of the session that created the binding.
    pub address: SocketAddr,
    /// Socket address supplied by the caller when the binding was created.
    pub endpoint: SocketAddr,
}
70
/// Monotonic tick counter used as the sessions' clock.
///
/// The timer never advances on its own. The owner calls [`Timer::add`] once
/// per tick (the session worker does so every second), and everything else
/// only reads it with [`Timer::get`].
///
/// ```
/// use turn_server::turn::sessions::Timer;
///
/// let timer = Timer::default();
///
/// assert_eq!(timer.get(), 0);
/// assert_eq!(timer.add(), 1);
/// assert_eq!(timer.get(), 1);
/// ```
#[derive(Default)]
pub struct Timer(AtomicU64);

impl Timer {
    /// Returns the current tick.
    pub fn get(&self) -> u64 {
        let Self(ticks) = self;
        ticks.load(Ordering::Relaxed)
    }

    /// Advances the timer by one tick and returns the new tick.
    pub fn add(&self) -> u64 {
        let previous = self.0.fetch_add(1, Ordering::Relaxed);
        previous + 1
    }
}
97
98#[derive(Default)]
99pub struct State {
100 sessions: RwLock<Table<SessionAddr, Session>>,
101 port_allocate_pool: Mutex<PortAllocatePools>,
102 // Records the sessions corresponding to each assigned port, which will be needed when looking
103 // up sessions assigned to this port based on the port number.
104 port_mapping_table: RwLock<Table</* port */ u16, SessionAddr>>,
105 // Records the nonce value for each network connection, which is independent of the session
106 // because it can exist before it is authenticated.
107 address_nonce_tanle: RwLock<Table<SessionAddr, (String, /* expires */ u64)>>,
108 // Stores the address to which the session should be forwarded when it sends indication to a
109 // port. This is written when permissions are created to allow a certain address to be
110 // forwarded to the current session.
111 port_relay_table: RwLock<Table<SessionAddr, HashMap</* port */ u16, Endpoint>>>,
112 // Indicates to which session the data sent by a session to a channel should be forwarded.
113 channel_relay_table: RwLock<Table<SessionAddr, HashMap</* channel */ u16, Endpoint>>>,
114}
115
116pub struct Sessions<T> {
117 timer: Timer,
118 state: State,
119 observer: T,
120}
121
122impl<T: Observer + 'static> Sessions<T> {
123 pub fn new(observer: T) -> Arc<Self> {
124 let this = Arc::new(Self {
125 state: State::default(),
126 timer: Timer::default(),
127 observer,
128 });
129
130 // This is a background thread that silently handles expiring sessions and
131 // cleans up session information when it expires.
132 let this_ = Arc::downgrade(&this);
133 thread::spawn(move || {
134 let mut address = Vec::with_capacity(255);
135
136 while let Some(this) = this_.upgrade() {
137 // The timer advances one second and gets the current time offset.
138 let now = this.timer.add();
139
140 // This is the part that deletes the session information.
141 {
142 // Finds sessions that have expired.
143 {
144 this.state
145 .sessions
146 .read()
147 .iter()
148 .filter(|(_, v)| v.expires <= now)
149 .for_each(|(k, _)| address.push(*k));
150 }
151
152 // Delete the expired sessions.
153 if !address.is_empty() {
154 this.remove_session(&address);
155 address.clear();
156 }
157 }
158
159 // Because nonce does not follow session creation, nonce is created for each
160 // addr, so nonce deletion is handled independently.
161 {
162 this.state
163 .address_nonce_tanle
164 .read()
165 .iter()
166 .filter(|(_, v)| v.1 <= now)
167 .for_each(|(k, _)| address.push(*k));
168
169 if !address.is_empty() {
170 this.remove_nonce(&address);
171 address.clear();
172 }
173 }
174
175 // Fixing a second tick.
176 sleep(Duration::from_secs(1));
177 }
178 });
179
180 this
181 }
182
183 fn remove_session(&self, addrs: &[SessionAddr]) {
184 let mut sessions = self.state.sessions.write();
185 let mut port_allocate_pool = self.state.port_allocate_pool.lock();
186 let mut port_mapping_table = self.state.port_mapping_table.write();
187 let mut port_relay_table = self.state.port_relay_table.write();
188 let mut channel_relay_table = self.state.channel_relay_table.write();
189
190 addrs.iter().for_each(|k| {
191 port_relay_table.remove(k);
192 channel_relay_table.remove(k);
193
194 if let Some(session) = sessions.remove(k) {
195 // Removes the session-bound port from the port binding table and
196 // releases the port back into the allocation pool.
197 if let Some(port) = session.allocate.port {
198 port_mapping_table.remove(&port);
199 port_allocate_pool.restore(port);
200 }
201
202 // Notifies that the external session has been closed.
203 self.observer.closed(k, &session.auth.username);
204 }
205 });
206 }
207
208 fn remove_nonce(&self, addrs: &[SessionAddr]) {
209 let mut address_nonce_tanle = self.state.address_nonce_tanle.write();
210
211 addrs.iter().for_each(|k| {
212 address_nonce_tanle.remove(k);
213 });
214 }
215
216 /// Get session for addr.
217 ///
218 /// # Test
219 ///
220 /// ```
221 /// use turn_server::turn::*;
222 ///
223 /// #[derive(Clone)]
224 /// struct ObserverTest;
225 ///
226 /// impl Observer for ObserverTest {
227 /// fn get_password(&self, username: &str) -> Option<String> {
228 /// if username == "test" {
229 /// Some("test".to_string())
230 /// } else {
231 /// None
232 /// }
233 /// }
234 /// }
235 ///
236 /// let addr = SessionAddr {
237 /// address: "127.0.0.1:8080".parse().unwrap(),
238 /// interface: "127.0.0.1:3478".parse().unwrap(),
239 /// };
240 ///
241 /// let digest = [
242 /// 174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
243 /// 239,
244 /// ];
245 ///
246 /// let sessions = Sessions::new(ObserverTest);
247 ///
248 /// assert!(sessions.get_session(&addr).get_ref().is_none());
249 ///
250 /// sessions.get_integrity(&addr, "test", "test");
251 ///
252 /// let lock = sessions.get_session(&addr);
253 /// let session = lock.get_ref().unwrap();
254 /// assert_eq!(session.auth.username, "test");
255 /// assert_eq!(session.allocate.port, None);
256 /// assert_eq!(session.allocate.channels.len(), 0);
257 /// ```
258 pub fn get_session<'a, 'b>(
259 &'a self,
260 key: &'b SessionAddr,
261 ) -> ReadLock<'b, 'a, SessionAddr, Table<SessionAddr, Session>> {
262 ReadLock {
263 lock: self.state.sessions.read(),
264 key,
265 }
266 }
267
268 /// Get nonce for addr.
269 ///
270 /// # Test
271 ///
272 /// ```
273 /// use turn_server::turn::*;
274 ///
275 /// #[derive(Clone)]
276 /// struct ObserverTest;
277 ///
278 /// impl Observer for ObserverTest {}
279 ///
280 /// let addr = SessionAddr {
281 /// address: "127.0.0.1:8080".parse().unwrap(),
282 /// interface: "127.0.0.1:3478".parse().unwrap(),
283 /// };
284 ///
285 /// let sessions = Sessions::new(ObserverTest);
286 ///
287 /// let a = sessions.get_nonce(&addr).get_ref().unwrap().clone();
288 /// assert!(a.0.len() == 16);
289 /// assert!(a.1 == 600 || a.1 == 601 || a.1 == 602);
290 ///
291 /// let b = sessions.get_nonce(&addr).get_ref().unwrap().clone();
292 /// assert_eq!(a.0, b.0);
293 /// assert!(b.1 == 600 || b.1 == 601 || b.1 == 602);
294 /// ```
295 pub fn get_nonce<'a, 'b>(
296 &'a self,
297 key: &'b SessionAddr,
298 ) -> ReadLock<'b, 'a, SessionAddr, Table<SessionAddr, (String, u64)>> {
299 // If no nonce is created, create a new one.
300 {
301 if !self.state.address_nonce_tanle.read().contains_key(key) {
302 self.state.address_nonce_tanle.write().insert(
303 *key,
304 (
305 // A random string of length 16.
306 {
307 let mut rng = thread_rng();
308 std::iter::repeat(())
309 .map(|_| rng.sample(Alphanumeric) as char)
310 .take(16)
311 .collect::<String>()
312 .to_lowercase()
313 },
314 // Current time stacks for 600 seconds.
315 self.timer.get() + 600,
316 ),
317 );
318 }
319 }
320
321 ReadLock {
322 lock: self.state.address_nonce_tanle.read(),
323 key,
324 }
325 }
326
327 /// Get digest for addr.
328 ///
329 /// # Test
330 ///
331 /// ```
332 /// use turn_server::turn::*;
333 ///
334 /// #[derive(Clone)]
335 /// struct ObserverTest;
336 ///
337 /// impl Observer for ObserverTest {
338 /// fn get_password(&self, username: &str) -> Option<String> {
339 /// if username == "test" {
340 /// Some("test".to_string())
341 /// } else {
342 /// None
343 /// }
344 /// }
345 /// }
346 ///
347 /// let addr = SessionAddr {
348 /// address: "127.0.0.1:8080".parse().unwrap(),
349 /// interface: "127.0.0.1:3478".parse().unwrap(),
350 /// };
351 ///
352 /// let digest = [
353 /// 174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
354 /// 239,
355 /// ];
356 ///
357 /// let sessions = Sessions::new(ObserverTest);
358 ///
359 /// assert_eq!(sessions.get_integrity(&addr, "test1", "test"), None);
360 ///
361 /// assert_eq!(sessions.get_integrity(&addr, "test", "test"), Some(digest));
362 ///
363 /// assert_eq!(sessions.get_integrity(&addr, "test", "test"), Some(digest));
364 /// ```
365 pub fn get_integrity(&self, addr: &SessionAddr, username: &str, realm: &str) -> Option<[u8; 16]> {
366 // Already authenticated, get the cached digest directly.
367 {
368 if let Some(it) = self.state.sessions.read().get(addr) {
369 return Some(it.auth.integrity);
370 }
371 }
372
373 // Get the current user's password from an external observer and create a
374 // digest.
375 let password = self.observer.get_password(username)?;
376 let integrity = long_term_credential_digest(&username, &password, realm);
377
378 // Record a new session.
379 {
380 self.state.sessions.write().insert(
381 *addr,
382 Session {
383 permissions: Vec::with_capacity(10),
384 expires: self.timer.get() + 600,
385 auth: Auth {
386 username: username.to_string(),
387 integrity,
388 },
389 allocate: Allocate {
390 channels: Vec::with_capacity(10),
391 port: None,
392 },
393 },
394 );
395 }
396
397 Some(integrity)
398 }
399
400 pub fn allocated(&self) -> usize {
401 self.state.port_allocate_pool.lock().len()
402 }
403
404 /// Assign a port number to the session.
405 ///
406 /// # Test
407 ///
408 /// ```
409 /// use turn_server::turn::*;
410 ///
411 /// #[derive(Clone)]
412 /// struct ObserverTest;
413 ///
414 /// impl Observer for ObserverTest {
415 /// fn get_password(&self, username: &str) -> Option<String> {
416 /// if username == "test" {
417 /// Some("test".to_string())
418 /// } else {
419 /// None
420 /// }
421 /// }
422 /// }
423 ///
424 /// let addr = SessionAddr {
425 /// address: "127.0.0.1:8080".parse().unwrap(),
426 /// interface: "127.0.0.1:3478".parse().unwrap(),
427 /// };
428 ///
429 /// let digest = [
430 /// 174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
431 /// 239,
432 /// ];
433 ///
434 /// let sessions = Sessions::new(ObserverTest);
435 ///
436 /// sessions.get_integrity(&addr, "test", "test");
437 ///
438 /// {
439 /// let lock = sessions.get_session(&addr);
440 /// let session = lock.get_ref().unwrap();
441 /// assert_eq!(session.auth.username, "test");
442 /// assert_eq!(session.allocate.port, None);
443 /// assert_eq!(session.allocate.channels.len(), 0);
444 /// }
445 ///
446 /// let port = sessions.allocate(&addr).unwrap();
447 /// {
448 /// let lock = sessions.get_session(&addr);
449 /// let session = lock.get_ref().unwrap();
450 /// assert_eq!(session.auth.username, "test");
451 /// assert_eq!(session.allocate.port, Some(port));
452 /// assert_eq!(session.allocate.channels.len(), 0);
453 /// }
454 ///
455 /// assert_eq!(sessions.allocate(&addr), Some(port));
456 /// ```
457 pub fn allocate(&self, addr: &SessionAddr) -> Option<u16> {
458 let mut lock = self.state.sessions.write();
459 let session = lock.get_mut(addr)?;
460
461 // If the port has already been allocated, re-allocation is not allowed.
462 if let Some(port) = session.allocate.port {
463 return Some(port);
464 }
465
466 // Records the port assigned to the current session and resets the alive time.
467 let port = self.state.port_allocate_pool.lock().alloc(None)?;
468 session.expires = self.timer.get() + 600;
469 session.allocate.port = Some(port);
470
471 // Write the allocation port binding table.
472 self.state.port_mapping_table.write().insert(port, *addr);
473 Some(port)
474 }
475
476 /// Create permission for session.
477 ///
478 /// # Test
479 ///
480 /// ```
481 /// use turn_server::turn::*;
482 ///
483 /// #[derive(Clone)]
484 /// struct ObserverTest;
485 ///
486 /// impl Observer for ObserverTest {
487 /// fn get_password(&self, username: &str) -> Option<String> {
488 /// if username == "test" {
489 /// Some("test".to_string())
490 /// } else {
491 /// None
492 /// }
493 /// }
494 /// }
495 ///
496 /// let endpoint = "127.0.0.1:3478".parse().unwrap();
497 /// let addr = SessionAddr {
498 /// address: "127.0.0.1:8080".parse().unwrap(),
499 /// interface: "127.0.0.1:3478".parse().unwrap(),
500 /// };
501 ///
502 /// let peer_addr = SessionAddr {
503 /// address: "127.0.0.1:8081".parse().unwrap(),
504 /// interface: "127.0.0.1:3478".parse().unwrap(),
505 /// };
506 ///
507 /// let digest = [
508 /// 174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
509 /// 239,
510 /// ];
511 ///
512 /// let sessions = Sessions::new(ObserverTest);
513 ///
514 /// sessions.get_integrity(&addr, "test", "test");
515 /// sessions.get_integrity(&peer_addr, "test", "test");
516 ///
517 /// let port = sessions.allocate(&addr).unwrap();
518 /// let peer_port = sessions.allocate(&peer_addr).unwrap();
519 ///
520 /// assert!(!sessions.create_permission(&addr, &endpoint, &[port]));
521 /// assert!(sessions.create_permission(&addr, &endpoint, &[peer_port]));
522 ///
523 /// assert!(!sessions.create_permission(&peer_addr, &endpoint, &[peer_port]));
524 /// assert!(sessions.create_permission(&peer_addr, &endpoint, &[port]));
525 /// ```
526 pub fn create_permission(&self, addr: &SessionAddr, endpoint: &SocketAddr, ports: &[u16]) -> bool {
527 let mut sessions = self.state.sessions.write();
528 let mut port_relay_table = self.state.port_relay_table.write();
529 let port_mapping_table = self.state.port_mapping_table.read();
530
531 // Finds information about the current session.
532 let session = if let Some(it) = sessions.get_mut(addr) {
533 it
534 } else {
535 return false;
536 };
537
538 // The port number assigned to the current session.
539 let local_port = if let Some(it) = session.allocate.port {
540 it
541 } else {
542 return false;
543 };
544
545 // You cannot create permissions for yourself.
546 if ports.contains(&local_port) {
547 return false;
548 }
549
550 // Each peer port must be present.
551 let mut peers = Vec::with_capacity(15);
552 for port in ports {
553 if let Some(it) = port_mapping_table.get(&port) {
554 peers.push((it, *port));
555 } else {
556 return false;
557 }
558 }
559
560 // Create a port forwarding mapping relationship for each peer session.
561 for (peer, port) in peers {
562 port_relay_table
563 .entry(*peer)
564 .or_insert_with(|| HashMap::with_capacity(20))
565 .insert(
566 local_port,
567 Endpoint {
568 address: addr.address,
569 endpoint: *endpoint,
570 },
571 );
572
573 // Do not store the same peer ports to the permission list over and over again.
574 if !session.permissions.contains(&port) {
575 session.permissions.push(port);
576 }
577 }
578
579 true
580 }
581
582 /// Binding a channel to the session.
583 ///
584 /// # Test
585 ///
586 /// ```
587 /// use turn_server::turn::*;
588 ///
589 /// #[derive(Clone)]
590 /// struct ObserverTest;
591 ///
592 /// impl Observer for ObserverTest {
593 /// fn get_password(&self, username: &str) -> Option<String> {
594 /// if username == "test" {
595 /// Some("test".to_string())
596 /// } else {
597 /// None
598 /// }
599 /// }
600 /// }
601 ///
602 /// let endpoint = "127.0.0.1:3478".parse().unwrap();
603 /// let addr = SessionAddr {
604 /// address: "127.0.0.1:8080".parse().unwrap(),
605 /// interface: "127.0.0.1:3478".parse().unwrap(),
606 /// };
607 ///
608 /// let peer_addr = SessionAddr {
609 /// address: "127.0.0.1:8081".parse().unwrap(),
610 /// interface: "127.0.0.1:3478".parse().unwrap(),
611 /// };
612 ///
613 /// let digest = [
614 /// 174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
615 /// 239,
616 /// ];
617 ///
618 /// let sessions = Sessions::new(ObserverTest);
619 ///
620 /// sessions.get_integrity(&addr, "test", "test");
621 /// sessions.get_integrity(&peer_addr, "test", "test");
622 ///
623 /// let port = sessions.allocate(&addr).unwrap();
624 /// let peer_port = sessions.allocate(&peer_addr).unwrap();
625 /// assert_eq!(
626 /// sessions
627 /// .get_session(&addr)
628 /// .get_ref()
629 /// .unwrap()
630 /// .allocate
631 /// .channels
632 /// .len(),
633 /// 0
634 /// );
635 ///
636 /// assert_eq!(
637 /// sessions
638 /// .get_session(&peer_addr)
639 /// .get_ref()
640 /// .unwrap()
641 /// .allocate
642 /// .channels
643 /// .len(),
644 /// 0
645 /// );
646 ///
647 /// assert!(sessions.bind_channel(&addr, &endpoint, peer_port, 0x4000));
648 /// assert!(sessions.bind_channel(&peer_addr, &endpoint, port, 0x4000));
649 /// assert_eq!(
650 /// sessions
651 /// .get_session(&addr)
652 /// .get_ref()
653 /// .unwrap()
654 /// .allocate
655 /// .channels,
656 /// vec![0x4000]
657 /// );
658 ///
659 /// assert_eq!(
660 /// sessions
661 /// .get_session(&peer_addr)
662 /// .get_ref()
663 /// .unwrap()
664 /// .allocate
665 /// .channels,
666 /// vec![0x4000]
667 /// );
668 /// ```
669 pub fn bind_channel(&self, addr: &SessionAddr, endpoint: &SocketAddr, port: u16, channel: u16) -> bool {
670 // Finds the address of the bound opposing port.
671 let peer = if let Some(it) = self.state.port_mapping_table.read().get(&port) {
672 *it
673 } else {
674 return false;
675 };
676
677 // Records the channel used for the current session.
678 {
679 let mut lock = self.state.sessions.write();
680 let session = if let Some(it) = lock.get_mut(addr) {
681 it
682 } else {
683 return false;
684 };
685
686 if !session.allocate.channels.contains(&channel) {
687 session.allocate.channels.push(channel);
688 }
689 }
690
691 // Binding ports also creates permissions.
692 if !self.create_permission(addr, endpoint, &[port]) {
693 return false;
694 }
695
696 // Create channel forwarding mapping relationships for peers.
697 self.state
698 .channel_relay_table
699 .write()
700 .entry(peer)
701 .or_insert_with(|| HashMap::with_capacity(10))
702 .insert(
703 channel,
704 Endpoint {
705 address: addr.address,
706 endpoint: *endpoint,
707 },
708 );
709
710 true
711 }
712
713 /// Gets the peer of the current session bound channel.
714 ///
715 /// # Test
716 ///
717 /// ```
718 /// use turn_server::turn::*;
719 ///
720 /// #[derive(Clone)]
721 /// struct ObserverTest;
722 ///
723 /// impl Observer for ObserverTest {
724 /// fn get_password(&self, username: &str) -> Option<String> {
725 /// if username == "test" {
726 /// Some("test".to_string())
727 /// } else {
728 /// None
729 /// }
730 /// }
731 /// }
732 ///
733 /// let endpoint = "127.0.0.1:3478".parse().unwrap();
734 /// let addr = SessionAddr {
735 /// address: "127.0.0.1:8080".parse().unwrap(),
736 /// interface: "127.0.0.1:3478".parse().unwrap(),
737 /// };
738 ///
739 /// let peer_addr = SessionAddr {
740 /// address: "127.0.0.1:8081".parse().unwrap(),
741 /// interface: "127.0.0.1:3478".parse().unwrap(),
742 /// };
743 ///
744 /// let digest = [
745 /// 174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
746 /// 239,
747 /// ];
748 ///
749 /// let sessions = Sessions::new(ObserverTest);
750 ///
751 /// sessions.get_integrity(&addr, "test", "test");
752 /// sessions.get_integrity(&peer_addr, "test", "test");
753 ///
754 /// let port = sessions.allocate(&addr).unwrap();
755 /// let peer_port = sessions.allocate(&peer_addr).unwrap();
756 ///
757 /// assert!(sessions.bind_channel(&addr, &endpoint, peer_port, 0x4000));
758 /// assert!(sessions.bind_channel(&peer_addr, &endpoint, port, 0x4000));
759 /// assert_eq!(
760 /// sessions
761 /// .get_channel_relay_address(&addr, 0x4000)
762 /// .unwrap()
763 /// .endpoint,
764 /// endpoint
765 /// );
766 ///
767 /// assert_eq!(
768 /// sessions
769 /// .get_channel_relay_address(&peer_addr, 0x4000)
770 /// .unwrap()
771 /// .endpoint,
772 /// endpoint
773 /// );
774 /// ```
775 pub fn get_channel_relay_address(&self, addr: &SessionAddr, channel: u16) -> Option<Endpoint> {
776 self.state.channel_relay_table.read().get(&addr)?.get(&channel).copied()
777 }
778
779 /// Get the address of the port binding.
780 ///
781 /// # Test
782 ///
783 /// ```
784 /// use turn_server::turn::*;
785 ///
786 /// #[derive(Clone)]
787 /// struct ObserverTest;
788 ///
789 /// impl Observer for ObserverTest {
790 /// fn get_password(&self, username: &str) -> Option<String> {
791 /// if username == "test" {
792 /// Some("test".to_string())
793 /// } else {
794 /// None
795 /// }
796 /// }
797 /// }
798 ///
799 /// let endpoint = "127.0.0.1:3478".parse().unwrap();
800 /// let addr = SessionAddr {
801 /// address: "127.0.0.1:8080".parse().unwrap(),
802 /// interface: "127.0.0.1:3478".parse().unwrap(),
803 /// };
804 ///
805 /// let peer_addr = SessionAddr {
806 /// address: "127.0.0.1:8081".parse().unwrap(),
807 /// interface: "127.0.0.1:3478".parse().unwrap(),
808 /// };
809 ///
810 /// let digest = [
811 /// 174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
812 /// 239,
813 /// ];
814 ///
815 /// let sessions = Sessions::new(ObserverTest);
816 ///
817 /// sessions.get_integrity(&addr, "test", "test");
818 /// sessions.get_integrity(&peer_addr, "test", "test");
819 ///
820 /// let port = sessions.allocate(&addr).unwrap();
821 /// let peer_port = sessions.allocate(&peer_addr).unwrap();
822 ///
823 /// assert!(sessions.create_permission(&addr, &endpoint, &[peer_port]));
824 /// assert!(sessions.create_permission(&peer_addr, &endpoint, &[port]));
825 ///
826 /// assert_eq!(
827 /// sessions
828 /// .get_relay_address(&addr, peer_port)
829 /// .unwrap()
830 /// .endpoint,
831 /// endpoint
832 /// );
833 ///
834 /// assert_eq!(
835 /// sessions
836 /// .get_relay_address(&peer_addr, port)
837 /// .unwrap()
838 /// .endpoint,
839 /// endpoint
840 /// );
841 /// ```
842 pub fn get_relay_address(&self, addr: &SessionAddr, port: u16) -> Option<Endpoint> {
843 self.state.port_relay_table.read().get(&addr)?.get(&port).copied()
844 }
845
846 /// Refresh the session for addr.
847 ///
848 /// # Test
849 ///
850 /// ```
851 /// use turn_server::turn::*;
852 ///
853 /// #[derive(Clone)]
854 /// struct ObserverTest;
855 ///
856 /// impl Observer for ObserverTest {
857 /// fn get_password(&self, username: &str) -> Option<String> {
858 /// if username == "test" {
859 /// Some("test".to_string())
860 /// } else {
861 /// None
862 /// }
863 /// }
864 /// }
865 ///
866 /// let addr = SessionAddr {
867 /// address: "127.0.0.1:8080".parse().unwrap(),
868 /// interface: "127.0.0.1:3478".parse().unwrap(),
869 /// };
870 ///
871 /// let digest = [
872 /// 174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
873 /// 239,
874 /// ];
875 ///
876 /// let sessions = Sessions::new(ObserverTest);
877 ///
878 /// assert!(sessions.get_session(&addr).get_ref().is_none());
879 ///
880 /// sessions.get_integrity(&addr, "test", "test");
881 ///
882 /// let expires = sessions.get_session(&addr).get_ref().unwrap().expires;
883 /// assert!(expires == 600 || expires == 601 || expires == 602);
884 ///
885 /// assert!(sessions.refresh(&addr, 0));
886 ///
887 /// assert!(sessions.get_session(&addr).get_ref().is_none());
888 /// ```
889 pub fn refresh(&self, addr: &SessionAddr, lifetime: u32) -> bool {
890 if lifetime > 3600 {
891 return false;
892 }
893
894 if lifetime == 0 {
895 self.remove_session(&[*addr]);
896 self.remove_nonce(&[*addr]);
897 } else {
898 if let Some(session) = self.state.sessions.write().get_mut(addr) {
899 session.expires = self.timer.get() + lifetime as u64;
900 } else {
901 return false;
902 }
903
904 if let Some(nonce) = self.state.address_nonce_tanle.write().get_mut(addr) {
905 nonce.1 = self.timer.get() + lifetime as u64;
906 }
907 }
908
909 true
910 }
911}
912
913/// The default HashMap is created without allocating capacity. To improve
914/// performance, the turn server needs to pre-allocate the available capacity.
915///
916/// So here the HashMap is rewrapped to allocate a large capacity (number of
917/// ports that can be allocated) at the default creation time as well.
918pub struct Table<K, V>(HashMap<K, V>);
919
920impl<K, V> Default for Table<K, V> {
921 fn default() -> Self {
922 Self(HashMap::with_capacity(PortAllocatePools::capacity()))
923 }
924}
925
926impl<K, V> AsRef<HashMap<K, V>> for Table<K, V> {
927 fn as_ref(&self) -> &HashMap<K, V> {
928 &self.0
929 }
930}
931
932impl<K, V> Deref for Table<K, V> {
933 type Target = HashMap<K, V>;
934
935 fn deref(&self) -> &Self::Target {
936 &self.0
937 }
938}
939
940impl<K, V> DerefMut for Table<K, V> {
941 fn deref_mut(&mut self) -> &mut Self::Target {
942 &mut self.0
943 }
944}
945
946/// Used to lengthen the timing of the release of a readable lock guard and to
947/// provide a more convenient way for external access to the lock's internal
948/// data.
949pub struct ReadLock<'a, 'b, K, R> {
950 key: &'a K,
951 lock: RwLockReadGuard<'b, R>,
952}
953
954impl<'a, 'b, K, V> ReadLock<'a, 'b, K, Table<K, V>>
955where
956 K: Eq + Hash,
957{
958 pub fn get_ref(&self) -> Option<&V> {
959 self.lock.get(self.key)
960 }
961}
962
/// State of one port slot in a [`PortAllocatePools`] bucket.
///
/// `High` marks the port as allocated, and `Low` marks it as free.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Bit {
    Low,
    High,
}
969
970/// Random Port
971///
972/// Recently, awareness has been raised about a number of "blind" attacks
973/// (i.e., attacks that can be performed without the need to sniff the
974/// packets that correspond to the transport protocol instance to be
975/// attacked) that can be performed against the Transmission Control
976/// Protocol (TCP) [RFC0793] and similar protocols. The consequences of
977/// these attacks range from throughput reduction to broken connections
978/// or data corruption [RFC5927] [RFC4953] [Watson].
979///
980/// All these attacks rely on the attacker's ability to guess or know the
981/// five-tuple (Protocol, Source Address, Source port, Destination
982/// Address, Destination Port) that identifies the transport protocol
983/// instance to be attacked.
984///
985/// Services are usually located at fixed, "well-known" ports [IANA] at
986/// the host supplying the service (the server). Client applications
987/// connecting to any such service will contact the server by specifying
988/// the server IP address and service port number. The IP address and
989/// port number of the client are normally left unspecified by the client
990/// application and thus are chosen automatically by the client
991/// networking stack. Ports chosen automatically by the networking stack
992/// are known as ephemeral ports [Stevens].
993///
994/// While the server IP address, the well-known port, and the client IP
995/// address may be known by an attacker, the ephemeral port of the client
996/// is usually unknown and must be guessed.
997///
998/// # Test
999///
1000/// ```
1001/// use std::collections::HashSet;
1002/// use turn_server::turn::sessions::*;
1003///
1004/// let mut pool = PortAllocatePools::default();
1005/// let mut ports = HashSet::with_capacity(PortAllocatePools::capacity());
1006///
1007/// while let Some(port) = pool.alloc(None) {
1008/// ports.insert(port);
1009/// }
1010///
1011/// assert_eq!(PortAllocatePools::capacity() + 1, ports.len());
1012/// ```
1013pub struct PortAllocatePools {
1014 pub buckets: Vec<u64>,
1015 allocated: usize,
1016 bit_len: u32,
1017 peak: usize,
1018}
1019
1020impl Default for PortAllocatePools {
1021 fn default() -> Self {
1022 Self {
1023 buckets: vec![0; Self::bucket_size()],
1024 peak: Self::bucket_size() - 1,
1025 bit_len: Self::bit_len(),
1026 allocated: 0,
1027 }
1028 }
1029}
1030
1031impl PortAllocatePools {
1032 /// compute bucket size.
1033 ///
1034 /// # Test
1035 ///
1036 /// ```
1037 /// use turn_server::turn::sessions::*;
1038 ///
1039 /// assert_eq!(PortAllocatePools::bucket_size(), 256);
1040 /// ```
1041 pub fn bucket_size() -> usize {
1042 (Self::capacity() as f32 / 64.0).ceil() as usize
1043 }
1044
1045 /// compute bucket last bit max offset.
1046 ///
1047 /// # Test
1048 ///
1049 /// ```
1050 /// use turn_server::turn::sessions::*;
1051 ///
1052 /// assert_eq!(PortAllocatePools::bit_len(), 63);
1053 /// ```
1054 pub fn bit_len() -> u32 {
1055 (Self::capacity() as f32 % 64.0).ceil() as u32
1056 }
1057
1058 /// get pools capacity.
1059 ///
1060 /// # Test
1061 ///
1062 /// ```
1063 /// use turn_server::turn::sessions::Bit;
1064 /// use turn_server::turn::sessions::PortAllocatePools;
1065 ///
1066 /// assert_eq!(PortAllocatePools::capacity(), 65535 - 49152);
1067 /// ```
1068 pub const fn capacity() -> usize {
1069 65535 - 49152
1070 }
1071
1072 /// get port range.
1073 ///
1074 /// # Test
1075 ///
1076 /// ```
1077 /// use turn_server::turn::sessions::*;
1078 ///
1079 /// assert_eq!(PortAllocatePools::port_range(), 49152..65535);
1080 /// ```
1081 pub const fn port_range() -> Range<u16> {
1082 49152..65535
1083 }
1084
1085 /// get pools allocated size.
1086 ///
1087 /// ```
1088 /// use turn_server::turn::sessions::PortAllocatePools;
1089 ///
1090 /// let mut pools = PortAllocatePools::default();
1091 /// assert_eq!(pools.len(), 0);
1092 ///
1093 /// pools.alloc(None).unwrap();
1094 /// assert_eq!(pools.len(), 1);
1095 /// ```
1096 pub fn len(&self) -> usize {
1097 self.allocated
1098 }
1099
1100 /// get pools allocated size is empty.
1101 ///
1102 /// ```
1103 /// use turn_server::turn::sessions::PortAllocatePools;
1104 ///
1105 /// let mut pools = PortAllocatePools::default();
1106 /// assert_eq!(pools.len(), 0);
1107 /// assert_eq!(pools.is_empty(), true);
1108 /// ```
1109 pub fn is_empty(&self) -> bool {
1110 self.allocated == 0
1111 }
1112
1113 /// random assign a port.
1114 ///
1115 /// # Test
1116 ///
1117 /// ```
1118 /// use turn_server::turn::sessions::PortAllocatePools;
1119 ///
1120 /// let mut pool = PortAllocatePools::default();
1121 ///
1122 /// assert_eq!(pool.alloc(Some(0)), Some(49152));
1123 /// assert_eq!(pool.alloc(Some(0)), Some(49153));
1124 ///
1125 /// assert!(pool.alloc(None).is_some());
1126 /// ```
1127 pub fn alloc(&mut self, start_index: Option<usize>) -> Option<u16> {
1128 let mut index = None;
1129 let mut start = start_index.unwrap_or_else(|| thread_rng().gen_range(0..self.peak as u16) as usize);
1130
1131 // When the partition lookup has gone through the entire partition list, the
1132 // lookup should be stopped, and the location where it should be stopped is
1133 // recorded here.
1134 let previous = if start == 0 { self.peak } else { start - 1 };
1135
1136 loop {
1137 // Finds the first high position in the partition.
1138 if let Some(i) = {
1139 let bucket = self.buckets[start];
1140 if bucket < u64::MAX {
1141 let offset = bucket.leading_ones();
1142
1143 // Check to see if the jump is beyond the partition list or the lookup exceeds
1144 // the maximum length of the allocation table.
1145 if start == self.peak && offset > self.bit_len {
1146 None
1147 } else {
1148 Some(offset)
1149 }
1150 } else {
1151 None
1152 }
1153 } {
1154 index = Some(i as usize);
1155 break;
1156 }
1157
1158 // As long as it doesn't find it, it continues to re-find it from the next
1159 // partition.
1160 if start == self.peak {
1161 start = 0;
1162 } else {
1163 start += 1;
1164 }
1165
1166 // Already gone through all partitions, lookup failed.
1167 if start == previous {
1168 break;
1169 }
1170 }
1171
1172 // Writes to the partition, marking the current location as already allocated.
1173 let index = index?;
1174 self.set_bit(start, index, Bit::High);
1175 self.allocated += 1;
1176
1177 // The actual port number is calculated from the partition offset position.
1178 let num = (start * 64 + index) as u16;
1179 let port = Self::port_range().start + num;
1180 Some(port)
1181 }
1182
1183 /// write bit flag in the bucket.
1184 ///
1185 /// # Test
1186 ///
1187 /// ```
1188 /// use turn_server::turn::sessions::Bit;
1189 /// use turn_server::turn::sessions::PortAllocatePools;
1190 ///
1191 /// let mut pool = PortAllocatePools::default();
1192 ///
1193 /// assert_eq!(pool.alloc(Some(0)), Some(49152));
1194 /// assert_eq!(pool.alloc(Some(0)), Some(49153));
1195 ///
1196 /// pool.set_bit(0, 0, Bit::High);
1197 /// pool.set_bit(0, 1, Bit::High);
1198 ///
1199 /// assert_eq!(pool.alloc(Some(0)), Some(49154));
1200 /// assert_eq!(pool.alloc(Some(0)), Some(49155));
1201 /// ```
1202 pub fn set_bit(&mut self, bucket: usize, index: usize, bit: Bit) {
1203 let high_mask = 1 << (63 - index);
1204 let mask = match bit {
1205 Bit::Low => u64::MAX ^ high_mask,
1206 Bit::High => high_mask,
1207 };
1208
1209 let value = self.buckets[bucket];
1210 self.buckets[bucket] = match bit {
1211 Bit::High => value | mask,
1212 Bit::Low => value & mask,
1213 };
1214 }
1215
1216 /// restore port in the buckets.
1217 ///
1218 /// # Test
1219 ///
1220 /// ```
1221 /// use turn_server::turn::sessions::PortAllocatePools;
1222 ///
1223 /// let mut pool = PortAllocatePools::default();
1224 ///
1225 /// assert_eq!(pool.alloc(Some(0)), Some(49152));
1226 /// assert_eq!(pool.alloc(Some(0)), Some(49153));
1227 ///
1228 /// pool.restore(49152);
1229 /// pool.restore(49153);
1230 ///
1231 /// assert_eq!(pool.alloc(Some(0)), Some(49152));
1232 /// assert_eq!(pool.alloc(Some(0)), Some(49153));
1233 /// ```
1234 pub fn restore(&mut self, port: u16) {
1235 assert!(Self::port_range().contains(&port));
1236
1237 // Calculate the location in the partition from the port number.
1238 let offset = (port - Self::port_range().start) as usize;
1239 let bucket = offset / 64;
1240 let index = offset - (bucket * 64);
1241
1242 // Gets the bit value in the port position in the partition, if it is low, no
1243 // processing is required.
1244 if {
1245 match (self.buckets[bucket] & (1 << (63 - index))) >> (63 - index) {
1246 0 => Bit::Low,
1247 1 => Bit::High,
1248 _ => panic!(),
1249 }
1250 } == Bit::Low
1251 {
1252 return;
1253 }
1254
1255 self.set_bit(bucket, index, Bit::Low);
1256 self.allocated -= 1;
1257 }
1258}