turn_server/service/session/ports.rs
1use std::{fmt::Display, str::FromStr};
2
3use rand::Rng;
4
5use serde::{Deserialize, Serialize};
6
/// A range of port numbers in which **both** `start` and `end` are included.
///
/// The textual form (see the `Display`/`FromStr` impls) is `"start..end"`,
/// e.g. `"49152..65535"`, again with `end` inclusive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PortRange {
    start: u16,
    end: u16,
}

impl PortRange {
    /// Number of ports in the range, counting both bounds.
    ///
    /// The result is computed in `usize` so the full range `0..65535`
    /// (65536 ports) does not overflow `u16`. A range whose `start` is
    /// greater than its `end` contains no ports and has size 0.
    pub fn size(&self) -> usize {
        if self.start > self.end {
            0
        } else {
            usize::from(self.end - self.start) + 1
        }
    }

    /// Returns `true` if `port` lies within `start..=end`.
    pub fn contains(&self, port: u16) -> bool {
        (self.start..=self.end).contains(&port)
    }

    /// First port of the range (inclusive).
    pub fn start(&self) -> u16 {
        self.start
    }

    /// Last port of the range (inclusive).
    pub fn end(&self) -> u16 {
        self.end
    }
}
30
31impl Default for PortRange {
32 fn default() -> Self {
33 Self {
34 start: 49152,
35 end: 65535,
36 }
37 }
38}
39
40impl From<std::ops::Range<u16>> for PortRange {
41 fn from(range: std::ops::Range<u16>) -> Self {
42 // Use debug_assert! instead of assert! to avoid panics in release builds
43 // This is a programming error, not a runtime error
44 debug_assert!(range.start <= range.end, "Port range start must be <= end");
45
46 Self {
47 start: range.start,
48 end: range.end,
49 }
50 }
51}
52
53impl Display for PortRange {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 write!(f, "{}..{}", self.start, self.end)
56 }
57}
58
/// Error returned when a string cannot be parsed into a [`PortRange`].
///
/// The wrapped string is the human-readable description shown by `Display`.
#[derive(Debug)]
pub struct PortRangeParseError(String);

impl std::error::Error for PortRangeParseError {}

impl std::fmt::Display for PortRangeParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<std::num::ParseIntError> for PortRangeParseError {
    /// Wraps an integer-parsing failure, keeping its message.
    fn from(error: std::num::ParseIntError) -> Self {
        let message = error.to_string();
        Self(message)
    }
}
75
76impl FromStr for PortRange {
77 type Err = PortRangeParseError;
78
79 fn from_str(s: &str) -> Result<Self, Self::Err> {
80 let (start, end) = s
81 .split_once("..")
82 .ok_or(PortRangeParseError(s.to_string()))?;
83
84 Ok(Self {
85 start: start.parse()?,
86 end: end.parse()?,
87 })
88 }
89}
90
91impl Serialize for PortRange {
92 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
93 where
94 S: serde::Serializer,
95 {
96 serializer.serialize_str(&self.to_string())
97 }
98}
99
100impl<'de> Deserialize<'de> for PortRange {
101 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
102 where
103 D: serde::Deserializer<'de>,
104 {
105 let s = String::deserialize(deserializer)?;
106 Self::from_str(&s).map_err(|e| serde::de::Error::custom(e.0))
107 }
108}
109
/// Bit flag written into the allocator bitmap by `PortAllocator::set_bit`.
///
/// `High` marks a port as allocated and `Low` marks it as free.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Bit {
    Low,
    High,
}
116
117/// Random Port
118///
119/// Recently, awareness has been raised about a number of "blind" attacks
120/// (i.e., attacks that can be performed without the need to sniff the
121/// packets that correspond to the transport protocol instance to be
122/// attacked) that can be performed against the Transmission Control
123/// Protocol (TCP) [RFC0793] and similar protocols. The consequences of
124/// these attacks range from throughput reduction to broken connections
125/// or data corruption [RFC5927] [RFC4953] [Watson].
126///
127/// All these attacks rely on the attacker's ability to guess or know the
128/// five-tuple (Protocol, Source Address, Source port, Destination
129/// Address, Destination Port) that identifies the transport protocol
130/// instance to be attacked.
131///
132/// Services are usually located at fixed, "well-known" ports [IANA] at
133/// the host supplying the service (the server). Client applications
134/// connecting to any such service will contact the server by specifying
135/// the server IP address and service port number. The IP address and
136/// port number of the client are normally left unspecified by the client
137/// application and thus are chosen automatically by the client
138/// networking stack. Ports chosen automatically by the networking stack
139/// are known as ephemeral ports [Stevens].
140///
141/// While the server IP address, the well-known port, and the client IP
142/// address may be known by an attacker, the ephemeral port of the client
143/// is usually unknown and must be guessed.
144///
145/// # Test
146///
147/// ```
148/// use std::collections::HashSet;
149/// use turn_server::service::session::ports::*;
150///
151/// let mut pool = PortAllocator::default();
152/// let mut ports = HashSet::with_capacity(PortAllocator::default().capacity());
153///
154/// while let Some(port) = pool.allocate(None) {
155/// ports.insert(port);
156/// }
157///
158/// assert_eq!(PortAllocator::default().capacity(), ports.len());
159/// ```
160pub struct PortAllocator {
161 port_range: PortRange,
162 buckets: Vec<u64>,
163 allocated: usize,
164 bit_len: u32,
165 max_offset: usize,
166}
167
168impl Default for PortAllocator {
169 fn default() -> Self {
170 Self::new(PortRange::default())
171 }
172}
173
174impl PortAllocator {
175 pub fn new(port_range: PortRange) -> Self {
176 let capacity = port_range.size();
177 let bucket_size = (capacity + 63) / 64;
178 let tail_bits = capacity % 64;
179 let bit_len = if tail_bits == 0 { 64 } else { tail_bits } as u32;
180
181 Self {
182 bit_len,
183 buckets: vec![0; bucket_size],
184 max_offset: bucket_size - 1,
185 allocated: 0,
186 port_range,
187 }
188 }
189
190 /// get pools capacity.
191 ///
192 /// # Test
193 ///
194 /// ```
195 /// use turn_server::service::session::ports::*;
196 ///
197 /// assert_eq!(PortAllocator::default().capacity(), 65535 - 49152 + 1);
198 /// ```
199 pub fn capacity(&self) -> usize {
200 self.port_range.size()
201 }
202
203 /// get port range.
204 ///
205 /// # Test
206 ///
207 /// ```
208 /// use turn_server::service::session::ports::*;
209 ///
210 /// let pool = PortAllocator::default();
211 ///
212 /// assert_eq!(pool.port_range().start(), 49152);
213 /// assert_eq!(pool.port_range().end(), 65535);
214 ///
215 /// let pool = PortAllocator::new((50000..60000).into());
216 ///
217 /// assert_eq!(pool.port_range().start(), 50000);
218 /// assert_eq!(pool.port_range().end(), 60000);
219 /// ```
220 pub fn port_range(&self) -> &PortRange {
221 &self.port_range
222 }
223
224 /// get pools allocated size.
225 ///
226 /// ```
227 /// use turn_server::service::session::ports::*;
228 ///
229 /// let mut pools = PortAllocator::default();
230 /// assert_eq!(pools.len(), 0);
231 ///
232 /// pools.allocate(None).unwrap();
233 /// assert_eq!(pools.len(), 1);
234 /// ```
235 pub fn len(&self) -> usize {
236 self.allocated
237 }
238
239 /// get pools allocated size is empty.
240 ///
241 /// ```
242 /// use turn_server::service::session::ports::*;
243 ///
244 /// let mut pools = PortAllocator::default();
245 /// assert_eq!(pools.len(), 0);
246 /// assert_eq!(pools.is_empty(), true);
247 /// ```
248 pub fn is_empty(&self) -> bool {
249 self.allocated == 0
250 }
251
252 /// random assign a port.
253 ///
254 /// # Test
255 ///
256 /// ```
257 /// use turn_server::service::session::ports::*;
258 ///
259 /// let mut pool = PortAllocator::default();
260 ///
261 /// assert_eq!(pool.allocate(Some(0)), Some(49152));
262 /// assert_eq!(pool.allocate(Some(0)), Some(49153));
263 ///
264 /// assert!(pool.allocate(None).is_some());
265 /// ```
266 pub fn allocate(&mut self, start: Option<usize>) -> Option<u16> {
267 let mut index = None;
268 let mut offset = start.unwrap_or_else(|| rand::rng().random_range(0..=self.max_offset));
269
270 // When the partition lookup has gone through the entire partition list, the
271 // lookup should be stopped, and the location where it should be stopped is
272 // recorded here.
273 let start_offset = offset;
274
275 loop {
276 // Finds the first high position in the partition.
277 if let Some(i) = {
278 let bucket = self.buckets[offset];
279 if bucket < u64::MAX {
280 let idx = bucket.leading_ones();
281
282 // Check to see if the jump is beyond the partition list or the lookup exceeds
283 // the maximum length of the allocation table.
284 if offset == self.max_offset && idx >= self.bit_len {
285 None
286 } else {
287 Some(idx)
288 }
289 } else {
290 None
291 }
292 } {
293 index = Some(i as usize);
294
295 break;
296 }
297
298 // As long as it doesn't find it, it continues to re-find it from the next
299 // partition.
300 if offset == self.max_offset {
301 offset = 0;
302 } else {
303 offset += 1;
304 }
305
306 // Already gone through all partitions, lookup failed.
307 if offset == start_offset {
308 break;
309 }
310 }
311
312 // Writes to the partition, marking the current location as already allocated.
313 let index = index?;
314 self.set_bit(offset, index, Bit::High);
315 self.allocated += 1;
316
317 // The actual port number is calculated from the partition offset position.
318 let num = (offset * 64 + index) as u16;
319 let port = self.port_range.start + num;
320 Some(port)
321 }
322
323 /// write bit flag in the bucket.
324 ///
325 /// # Test
326 ///
327 /// ```
328 /// use turn_server::service::session::ports::*;
329 ///
330 /// let mut pool = PortAllocator::default();
331 ///
332 /// assert_eq!(pool.allocate(Some(0)), Some(49152));
333 /// assert_eq!(pool.allocate(Some(0)), Some(49153));
334 ///
335 /// pool.set_bit(0, 0, Bit::High);
336 /// pool.set_bit(0, 1, Bit::High);
337 ///
338 /// assert_eq!(pool.allocate(Some(0)), Some(49154));
339 /// assert_eq!(pool.allocate(Some(0)), Some(49155));
340 /// ```
341 pub fn set_bit(&mut self, bucket: usize, index: usize, bit: Bit) {
342 let high_mask = 1 << (63 - index);
343 let mask = match bit {
344 Bit::Low => u64::MAX ^ high_mask,
345 Bit::High => high_mask,
346 };
347
348 let value = self.buckets[bucket];
349 self.buckets[bucket] = match bit {
350 Bit::High => value | mask,
351 Bit::Low => value & mask,
352 };
353 }
354
355 /// deallocate port in the buckets.
356 ///
357 /// # Test
358 ///
359 /// ```
360 /// use turn_server::service::session::ports::*;
361 ///
362 /// let mut pool = PortAllocator::default();
363 ///
364 /// assert_eq!(pool.allocate(Some(0)), Some(49152));
365 /// assert_eq!(pool.allocate(Some(0)), Some(49153));
366 ///
367 /// pool.deallocate(49152);
368 /// pool.deallocate(49153);
369 ///
370 /// assert_eq!(pool.allocate(Some(0)), Some(49152));
371 /// assert_eq!(pool.allocate(Some(0)), Some(49153));
372 /// ```
373 pub fn deallocate(&mut self, port: u16) {
374 assert!(self.port_range.contains(port));
375
376 // Calculate the location in the partition from the port number.
377 let offset = (port - self.port_range.start) as usize;
378 let bucket = offset / 64;
379 let index = offset - (bucket * 64);
380
381 // Gets the bit value in the port position in the partition, if it is low, no
382 // processing is required.
383 let bit = match (self.buckets[bucket] & (1 << (63 - index))) >> (63 - index) {
384 0 => Bit::Low,
385 1 => Bit::High,
386 _ => unreachable!("Bit value can only be 0 or 1"),
387 };
388
389 if bit == Bit::Low {
390 return;
391 }
392
393 self.set_bit(bucket, index, Bit::Low);
394 self.allocated -= 1;
395 }
396}
397
#[cfg(test)]
mod tests {
    use super::*;

    use std::collections::HashSet;

    #[test]
    fn allocate_all_ports_without_gaps_includes_tail_bits() {
        let range = PortRange::from(50000..50069);
        let mut pool = PortAllocator::new(range);
        let mut ports = HashSet::new();

        while let Some(port) = pool.allocate(None) {
            assert!(range.contains(port));
            assert!(ports.insert(port));
        }

        assert_eq!(pool.capacity(), ports.len());
        assert_eq!(range.start(), *ports.iter().min().unwrap());
        assert_eq!(range.end(), *ports.iter().max().unwrap());
    }

    #[test]
    fn random_allocation_varies_first_port() {
        let range = PortRange::from(50000..50127);
        let mut first_ports = HashSet::new();

        for _ in 0..128 {
            let mut pool = PortAllocator::new(range);
            if let Some(port) = pool.allocate(None) {
                first_ports.insert(port);
            }
        }

        assert!(first_ports.len() > 1);
    }

    #[test]
    fn port_range_display_round_trips_through_from_str() {
        let range = PortRange::from(50000..50069);
        let text = range.to_string();

        assert_eq!(text, "50000..50069");
        assert_eq!(text.parse::<PortRange>().unwrap(), range);
    }

    #[test]
    fn port_range_rejects_malformed_input() {
        assert!("50000".parse::<PortRange>().is_err());
        assert!("a..b".parse::<PortRange>().is_err());
        assert!("50000..70000".parse::<PortRange>().is_err());
    }

    #[test]
    fn exhausted_pool_returns_none() {
        let mut pool = PortAllocator::new(PortRange::from(50000..50002));

        assert_eq!(pool.allocate(None), Some(50000));
        assert_eq!(pool.allocate(None), Some(50001));
        assert_eq!(pool.allocate(None), Some(50002));
        assert_eq!(pool.allocate(None), None);
        assert_eq!(pool.len(), pool.capacity());
    }

    #[test]
    fn deallocate_releases_port_and_is_idempotent() {
        let mut pool = PortAllocator::new(PortRange::from(50000..50069));

        // Bucket 1 holds the 6 tail ports 50064..=50069.
        assert_eq!(pool.allocate(Some(1)), Some(50064));
        assert_eq!(pool.len(), 1);

        pool.deallocate(50064);
        assert!(pool.is_empty());

        // Releasing an already-free port must not change the counter.
        pool.deallocate(50064);
        assert!(pool.is_empty());

        assert_eq!(pool.allocate(Some(1)), Some(50064));
    }
}