Skip to main content

turn_server/service/session/
ports.rs

1use std::{fmt::Display, str::FromStr};
2
3use rand::Rng;
4
5use serde::{Deserialize, Serialize};
6
/// A range of port numbers whose **both** ends are inclusive (`start..=end`).
///
/// It is written and parsed as `start..end` (see the `Display` / `FromStr`
/// impls) and serialized as that same string. Code that sizes or walks the
/// range expects `start <= end`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PortRange {
    /// First port of the range (inclusive).
    start: u16,
    /// Last port of the range (inclusive).
    end: u16,
}

impl PortRange {
    /// Number of ports in the range, counting both endpoints.
    ///
    /// The arithmetic is done in `usize`, so the whole port space
    /// (`0..65535`, i.e. 65 536 ports) does not overflow `u16`.
    ///
    /// # Test
    ///
    /// ```
    /// use turn_server::service::session::ports::*;
    ///
    /// assert_eq!(PortRange::from(50000..60000).size(), 10001);
    /// assert_eq!(PortRange::from(0..65535).size(), 65536);
    /// ```
    pub fn size(&self) -> usize {
        // Widen before subtracting: `end - start + 1` can be 65 536, which
        // does not fit in a u16. Relies on the `start <= end` invariant.
        usize::from(self.end) - usize::from(self.start) + 1
    }

    /// Returns `true` when `port` lies within `start..=end`.
    pub fn contains(&self, port: u16) -> bool {
        (self.start..=self.end).contains(&port)
    }

    /// First port of the range (inclusive).
    pub fn start(&self) -> u16 {
        self.start
    }

    /// Last port of the range (inclusive).
    pub fn end(&self) -> u16 {
        self.end
    }
}
30
31impl Default for PortRange {
32    fn default() -> Self {
33        Self {
34            start: 49152,
35            end: 65535,
36        }
37    }
38}
39
40impl From<std::ops::Range<u16>> for PortRange {
41    fn from(range: std::ops::Range<u16>) -> Self {
42        // Use debug_assert! instead of assert! to avoid panics in release builds
43        // This is a programming error, not a runtime error
44        debug_assert!(range.start <= range.end, "Port range start must be <= end");
45
46        Self {
47            start: range.start,
48            end: range.end,
49        }
50    }
51}
52
53impl Display for PortRange {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        write!(f, "{}..{}", self.start, self.end)
56    }
57}
58
/// Error produced when text cannot be parsed as a [`PortRange`].
///
/// The wrapped string is exactly what `Display` prints.
#[derive(Debug)]
pub struct PortRangeParseError(String);

impl std::error::Error for PortRangeParseError {}

impl std::fmt::Display for PortRangeParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<std::num::ParseIntError> for PortRangeParseError {
    /// Wraps a failure to parse either bound as a `u16`, keeping its message.
    fn from(error: std::num::ParseIntError) -> Self {
        Self(error.to_string())
    }
}
75
76impl FromStr for PortRange {
77    type Err = PortRangeParseError;
78
79    fn from_str(s: &str) -> Result<Self, Self::Err> {
80        let (start, end) = s
81            .split_once("..")
82            .ok_or(PortRangeParseError(s.to_string()))?;
83
84        Ok(Self {
85            start: start.parse()?,
86            end: end.parse()?,
87        })
88    }
89}
90
91impl Serialize for PortRange {
92    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
93    where
94        S: serde::Serializer,
95    {
96        serializer.serialize_str(&self.to_string())
97    }
98}
99
100impl<'de> Deserialize<'de> for PortRange {
101    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
102    where
103        D: serde::Deserializer<'de>,
104    {
105        let s = String::deserialize(deserializer)?;
106        Self::from_str(&s).map_err(|e| serde::de::Error::custom(e.0))
107    }
108}
109
/// State of one slot in the port-allocation bitmap.
///
/// `PortAllocator` marks a port `High` when it hands it out and sets it back
/// to `Low` when it is released.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Bit {
    /// The slot is free.
    Low,
    /// The slot is allocated.
    High,
}
116
117/// Random Port
118///
119/// Recently, awareness has been raised about a number of "blind" attacks
120/// (i.e., attacks that can be performed without the need to sniff the
121/// packets that correspond to the transport protocol instance to be
122/// attacked) that can be performed against the Transmission Control
123/// Protocol (TCP) [RFC0793] and similar protocols.  The consequences of
124/// these attacks range from throughput reduction to broken connections
125/// or data corruption [RFC5927] [RFC4953] [Watson].
126///
127/// All these attacks rely on the attacker's ability to guess or know the
128/// five-tuple (Protocol, Source Address, Source port, Destination
129/// Address, Destination Port) that identifies the transport protocol
130/// instance to be attacked.
131///
132/// Services are usually located at fixed, "well-known" ports [IANA] at
133/// the host supplying the service (the server).  Client applications
134/// connecting to any such service will contact the server by specifying
135/// the server IP address and service port number.  The IP address and
136/// port number of the client are normally left unspecified by the client
137/// application and thus are chosen automatically by the client
138/// networking stack.  Ports chosen automatically by the networking stack
139/// are known as ephemeral ports [Stevens].
140///
141/// While the server IP address, the well-known port, and the client IP
142/// address may be known by an attacker, the ephemeral port of the client
143/// is usually unknown and must be guessed.
144///
145/// # Test
146///
147/// ```
148/// use std::collections::HashSet;
149/// use turn_server::service::session::ports::*;
150///
151/// let mut pool = PortAllocator::default();
152/// let mut ports = HashSet::with_capacity(PortAllocator::default().capacity());
153///
154/// while let Some(port) = pool.allocate(None) {
155///     ports.insert(port);
156/// }
157///
158/// assert_eq!(PortAllocator::default().capacity(), ports.len());
159/// ```
160pub struct PortAllocator {
161    port_range: PortRange,
162    buckets: Vec<u64>,
163    allocated: usize,
164    bit_len: u32,
165    max_offset: usize,
166}
167
168impl Default for PortAllocator {
169    fn default() -> Self {
170        Self::new(PortRange::default())
171    }
172}
173
174impl PortAllocator {
175    pub fn new(port_range: PortRange) -> Self {
176        let capacity = port_range.size();
177        let bucket_size = (capacity + 63) / 64;
178        let tail_bits = capacity % 64;
179        let bit_len = if tail_bits == 0 { 64 } else { tail_bits } as u32;
180
181        Self {
182            bit_len,
183            buckets: vec![0; bucket_size],
184            max_offset: bucket_size - 1,
185            allocated: 0,
186            port_range,
187        }
188    }
189
190    /// get pools capacity.
191    ///
192    /// # Test
193    ///
194    /// ```
195    /// use turn_server::service::session::ports::*;
196    ///
197    /// assert_eq!(PortAllocator::default().capacity(), 65535 - 49152 + 1);
198    /// ```
199    pub fn capacity(&self) -> usize {
200        self.port_range.size()
201    }
202
203    /// get port range.
204    ///
205    /// # Test
206    ///
207    /// ```
208    /// use turn_server::service::session::ports::*;
209    ///
210    /// let pool = PortAllocator::default();
211    ///
212    /// assert_eq!(pool.port_range().start(), 49152);
213    /// assert_eq!(pool.port_range().end(), 65535);
214    ///
215    /// let pool = PortAllocator::new((50000..60000).into());
216    ///
217    /// assert_eq!(pool.port_range().start(), 50000);
218    /// assert_eq!(pool.port_range().end(), 60000);
219    /// ```
220    pub fn port_range(&self) -> &PortRange {
221        &self.port_range
222    }
223
224    /// get pools allocated size.
225    ///
226    /// ```
227    /// use turn_server::service::session::ports::*;
228    ///
229    /// let mut pools = PortAllocator::default();
230    /// assert_eq!(pools.len(), 0);
231    ///
232    /// pools.allocate(None).unwrap();
233    /// assert_eq!(pools.len(), 1);
234    /// ```
235    pub fn len(&self) -> usize {
236        self.allocated
237    }
238
239    /// get pools allocated size is empty.
240    ///
241    /// ```
242    /// use turn_server::service::session::ports::*;
243    ///
244    /// let mut pools = PortAllocator::default();
245    /// assert_eq!(pools.len(), 0);
246    /// assert_eq!(pools.is_empty(), true);
247    /// ```
248    pub fn is_empty(&self) -> bool {
249        self.allocated == 0
250    }
251
252    /// random assign a port.
253    ///
254    /// # Test
255    ///
256    /// ```
257    /// use turn_server::service::session::ports::*;
258    ///
259    /// let mut pool = PortAllocator::default();
260    ///
261    /// assert_eq!(pool.allocate(Some(0)), Some(49152));
262    /// assert_eq!(pool.allocate(Some(0)), Some(49153));
263    ///
264    /// assert!(pool.allocate(None).is_some());
265    /// ```
266    pub fn allocate(&mut self, start: Option<usize>) -> Option<u16> {
267        let mut index = None;
268        let mut offset = start.unwrap_or_else(|| rand::rng().random_range(0..=self.max_offset));
269
270        // When the partition lookup has gone through the entire partition list, the
271        // lookup should be stopped, and the location where it should be stopped is
272        // recorded here.
273        let start_offset = offset;
274
275        loop {
276            // Finds the first high position in the partition.
277            if let Some(i) = {
278                let bucket = self.buckets[offset];
279                if bucket < u64::MAX {
280                    let idx = bucket.leading_ones();
281
282                    // Check to see if the jump is beyond the partition list or the lookup exceeds
283                    // the maximum length of the allocation table.
284                    if offset == self.max_offset && idx >= self.bit_len {
285                        None
286                    } else {
287                        Some(idx)
288                    }
289                } else {
290                    None
291                }
292            } {
293                index = Some(i as usize);
294                
295                break;
296            }
297
298            // As long as it doesn't find it, it continues to re-find it from the next
299            // partition.
300            if offset == self.max_offset {
301                offset = 0;
302            } else {
303                offset += 1;
304            }
305
306            // Already gone through all partitions, lookup failed.
307            if offset == start_offset {
308                break;
309            }
310        }
311
312        // Writes to the partition, marking the current location as already allocated.
313        let index = index?;
314        self.set_bit(offset, index, Bit::High);
315        self.allocated += 1;
316
317        // The actual port number is calculated from the partition offset position.
318        let num = (offset * 64 + index) as u16;
319        let port = self.port_range.start + num;
320        Some(port)
321    }
322
323    /// write bit flag in the bucket.
324    ///
325    /// # Test
326    ///
327    /// ```
328    /// use turn_server::service::session::ports::*;
329    ///
330    /// let mut pool = PortAllocator::default();
331    ///
332    /// assert_eq!(pool.allocate(Some(0)), Some(49152));
333    /// assert_eq!(pool.allocate(Some(0)), Some(49153));
334    ///
335    /// pool.set_bit(0, 0, Bit::High);
336    /// pool.set_bit(0, 1, Bit::High);
337    ///
338    /// assert_eq!(pool.allocate(Some(0)), Some(49154));
339    /// assert_eq!(pool.allocate(Some(0)), Some(49155));
340    /// ```
341    pub fn set_bit(&mut self, bucket: usize, index: usize, bit: Bit) {
342        let high_mask = 1 << (63 - index);
343        let mask = match bit {
344            Bit::Low => u64::MAX ^ high_mask,
345            Bit::High => high_mask,
346        };
347
348        let value = self.buckets[bucket];
349        self.buckets[bucket] = match bit {
350            Bit::High => value | mask,
351            Bit::Low => value & mask,
352        };
353    }
354
355    /// deallocate port in the buckets.
356    ///
357    /// # Test
358    ///
359    /// ```
360    /// use turn_server::service::session::ports::*;
361    ///
362    /// let mut pool = PortAllocator::default();
363    ///
364    /// assert_eq!(pool.allocate(Some(0)), Some(49152));
365    /// assert_eq!(pool.allocate(Some(0)), Some(49153));
366    ///
367    /// pool.deallocate(49152);
368    /// pool.deallocate(49153);
369    ///
370    /// assert_eq!(pool.allocate(Some(0)), Some(49152));
371    /// assert_eq!(pool.allocate(Some(0)), Some(49153));
372    /// ```
373    pub fn deallocate(&mut self, port: u16) {
374        assert!(self.port_range.contains(port));
375
376        // Calculate the location in the partition from the port number.
377        let offset = (port - self.port_range.start) as usize;
378        let bucket = offset / 64;
379        let index = offset - (bucket * 64);
380
381        // Gets the bit value in the port position in the partition, if it is low, no
382        // processing is required.
383        let bit = match (self.buckets[bucket] & (1 << (63 - index))) >> (63 - index) {
384            0 => Bit::Low,
385            1 => Bit::High,
386            _ => unreachable!("Bit value can only be 0 or 1"),
387        };
388        
389        if bit == Bit::Low {
390            return;
391        }
392
393        self.set_bit(bucket, index, Bit::Low);
394        self.allocated -= 1;
395    }
396}
397
#[cfg(test)]
mod tests {
    use super::*;

    use std::collections::HashSet;

    /// 70 ports make one full bucket plus a 6-bit tail. Draining the pool must
    /// yield each of them exactly once, including the ports in the tail.
    #[test]
    fn allocate_all_ports_without_gaps_includes_tail_bits() {
        let range = PortRange::from(50000..50069);
        let mut allocator = PortAllocator::new(range);
        let mut seen = HashSet::new();

        for port in std::iter::from_fn(|| allocator.allocate(None)) {
            assert!(range.contains(port));
            assert!(seen.insert(port));
        }

        assert_eq!(allocator.capacity(), seen.len());
        assert_eq!(Some(range.start()), seen.iter().copied().min());
        assert_eq!(Some(range.end()), seen.iter().copied().max());
    }

    /// Fresh pools allocating without a hint should not always start at the
    /// same port: with two buckets, 128 trials must see more than one first port.
    #[test]
    fn random_allocation_varies_first_port() {
        let range = PortRange::from(50000..50127);

        let first_ports: HashSet<u16> = (0..128)
            .filter_map(|_| PortAllocator::new(range).allocate(None))
            .collect();

        assert!(first_ports.len() > 1);
    }
}