utp_socket/
reorder_buffer.rs

1use crate::utp_packet::Packet;
2
3#[derive(Debug)]
4pub struct ReorderBuffer {
5    buffer: Box<[Option<Packet>]>,
6    first: usize,
7    last: usize,
8    size: usize,
9}
10
11impl ReorderBuffer {
12    // size in PACKETS
13    pub fn new(size: usize) -> Self {
14        ReorderBuffer {
15            // TODO vec not needed
16            buffer: vec![Option::<Packet>::None; size].into_boxed_slice(),
17            first: 0,
18            last: 0,
19            size: 0,
20        }
21    }
22
23    pub fn insert(&mut self, packet: Packet) {
24        if self.get(packet.header.seq_nr).is_some() {
25            return;
26        }
27
28        let position = packet.header.seq_nr as i32;
29
30        let (first_val, last_val) = if let (Some(first_val), Some(last_val)) = (
31            self.buffer[self.first]
32                .as_ref()
33                .map(|p| p.header.seq_nr as i32),
34            self.buffer[self.last]
35                .as_ref()
36                .map(|p| p.header.seq_nr as i32),
37        ) {
38            (first_val, last_val)
39        } else {
40            self.size += packet.data.len();
41            self.buffer[self.first] = Some(packet);
42            debug_assert!(self.last == self.first);
43            return;
44        };
45
46        self.size += packet.data.len();
47        match position.cmp(&first_val) {
48            std::cmp::Ordering::Less => {
49                // If the available space is less than the distance to
50                // the first value we need to realloc
51                if self.buffer.len() as i32 - (last_val - first_val) <= first_val - position {
52                    // Ensure current span + new value fits
53                    // by calculating dist between position (< first_val)
54                    // and last_val
55                    self.resize(1 + (last_val - position) as usize);
56                    // This is conceptually the same as what's done in the
57                    // else branch but avoids the rem_euclid operation since
58                    // we know that first always = 0 after resizing
59                    let new_first = self.buffer.len() - (first_val - position) as usize;
60                    self.buffer[new_first] = Some(packet);
61                    self.first = new_first;
62                } else {
63                    // there is capacity for it
64                    let new_first = (self.first as i32 - (first_val - position))
65                        .rem_euclid(self.buffer.len() as i32);
66                    self.buffer[new_first as usize] = Some(packet);
67                    self.first = new_first as usize;
68                }
69            }
70            std::cmp::Ordering::Greater => {
71                // there is capacity for it
72                if first_val as usize + self.buffer.len() > position as usize {
73                    let index = (self.first + (position - first_val) as usize) % self.buffer.len();
74                    if last_val < packet.header.seq_nr as i32 {
75                        self.last = index;
76                    }
77                    self.buffer[index] = Some(packet);
78                } else {
79                    self.resize(1 + (position - first_val) as usize);
80                    let index = self.first + (position - first_val) as usize;
81                    if last_val < packet.header.seq_nr as i32 {
82                        self.last = index;
83                    }
84                    self.buffer[index] = Some(packet);
85                }
86            }
87            std::cmp::Ordering::Equal => unreachable!(),
88        }
89    }
90
91    fn resize(&mut self, min_size: usize) {
92        let new_size = std::cmp::max(min_size, self.buffer.len() * 2);
93        let mut buf_new = vec![Option::<Packet>::None; new_size].into_boxed_slice();
94        let first_part = &self.buffer[self.first..];
95        let second_part = &self.buffer[..self.first];
96        // Can't use ptr copy since Bytes isn't copy
97        buf_new[0..first_part.len()].clone_from_slice(first_part);
98        buf_new[first_part.len()..first_part.len() + second_part.len()]
99            .clone_from_slice(second_part);
100        let old_cap = self.buffer.len();
101        self.buffer = buf_new;
102        match self.last.cmp(&self.first) {
103            // Move it to an unrwapped position
104            std::cmp::Ordering::Less => self.last += old_cap - self.first,
105            // Move it to the new first position
106            std::cmp::Ordering::Equal => self.last = 0,
107            // Move first position steps down since first is moved to 0
108            std::cmp::Ordering::Greater => self.last -= self.first,
109        }
110        self.first = 0;
111    }
112
113    #[inline]
114    fn index_of(&self, position: i32) -> Option<usize> {
115        let first_val = self.buffer[self.first].as_ref()?.header.seq_nr as i32;
116        Some(
117            (self.first as i32 + (position - first_val)).rem_euclid(self.buffer.len() as i32)
118                as usize,
119        )
120    }
121
122    #[inline]
123    pub fn get(&self, position: u16) -> Option<&Packet> {
124        let index = self.index_of(position as i32)?;
125        self.buffer[index]
126            .as_ref()
127            .filter(|packet| packet.header.seq_nr == position)
128    }
129
130    // TODO make sequential removal more efficient
131    pub fn remove(&mut self, position: u16) -> Option<Packet> {
132        let index = self.index_of(position as i32)?;
133
134        let maybe_packet = self.buffer[index].take();
135        if let Some(packet) = maybe_packet.as_ref() {
136            if packet.header.seq_nr == position {
137                if self.first == index {
138                    // Only one element in the buffer
139                    if self.buffer[self.last].is_none() {
140                        self.first = 0;
141                        self.last = 0;
142                    } else {
143                        // find new first index
144                        self.first += 1;
145                        self.first %= self.buffer.len();
146                        while self.first != self.last && self.buffer[self.first].is_none() {
147                            self.first += 1;
148                            self.first %= self.buffer.len();
149                        }
150                    }
151                } else if self.last == index {
152                    // Only one element in the buffer
153                    if self.buffer[self.first].is_none() {
154                        self.first = 0;
155                        self.last = 0;
156                    } else {
157                        // find new last index
158                        self.last =
159                            (self.last as i32 - 1).rem_euclid(self.buffer.len() as i32) as usize;
160                        while self.first != self.last && self.buffer[self.last].is_none() {
161                            self.last = (self.last as i32 - 1).rem_euclid(self.buffer.len() as i32)
162                                as usize;
163                        }
164                    }
165                }
166                if let Some(packet) = maybe_packet {
167                    self.size -= packet.data.len();
168                    return Some(packet);
169                } else {
170                    return None;
171                }
172            } else {
173                self.buffer[index] = maybe_packet;
174            }
175        }
176        None
177    }
178
179    pub fn is_empty(&self) -> bool {
180        let empty = self.buffer[self.first].is_none();
181        // santiy check
182        if empty {
183            debug_assert!(self.buffer[self.last].is_none());
184        }
185        empty
186    }
187
188    // Lenght in the number of packets
189    #[inline]
190    pub fn len(&self) -> usize {
191        if self.is_empty() {
192            0
193        } else {
194            self.iter().count()
195        }
196    }
197
198    // Size in bytes
199    #[inline(always)]
200    pub fn size(&self) -> usize {
201        self.size
202    }
203
204    // TODO remove allocations from this
205    pub fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = &'a Packet> + 'a> {
206        if self.first <= self.last {
207            Box::new(
208                self.buffer[self.first..self.last + 1]
209                    .iter()
210                    .filter_map(|maybe_packet| maybe_packet.as_ref()),
211            )
212        } else {
213            Box::new(
214                self.buffer[self.first..]
215                    .iter()
216                    .chain(self.buffer[..self.last + 1].iter())
217                    .filter_map(|maybe_packet| maybe_packet.as_ref()),
218            )
219        }
220    }
221}
222
#[cfg(test)]
mod test {

    use bytes::Bytes;

    use crate::utp_packet::PacketHeader;

    use super::*;

    /// Builds a data packet with sequence number `seq_nr` carrying `data`.
    fn make_packet_with_data(seq_nr: u16, data: Bytes) -> Packet {
        Packet {
            header: PacketHeader {
                seq_nr,
                ack_nr: 0,
                conn_id: 0,
                packet_type: crate::utp_packet::PacketType::Data,
                timestamp_microseconds: 0,
                timestamp_difference_microseconds: 0,
                wnd_size: 0,
                extension: 0,
            },
            data,
        }
    }

    /// Builds an empty data packet with sequence number `seq_nr`.
    fn make_packet(seq_nr: u16) -> Packet {
        make_packet_with_data(seq_nr, Bytes::new())
    }

    /// Creates a buffer with `capacity` slots and inserts `seq_nrs` in order.
    fn buffer_with(capacity: usize, seq_nrs: &[u16]) -> ReorderBuffer {
        let mut buffer = ReorderBuffer::new(capacity);
        for &seq_nr in seq_nrs {
            buffer.insert(make_packet(seq_nr));
        }
        buffer
    }

    /// Asserts that every sequence number in `seq_nrs` can be looked up.
    fn assert_contains_all(buffer: &ReorderBuffer, seq_nrs: &[u16]) {
        for &seq_nr in seq_nrs {
            let packet = buffer.get(seq_nr).unwrap();
            assert_eq!(packet.header.seq_nr, seq_nr);
        }
    }

    /// Removes `seq_nrs` in order, checking each removal, then asserts the
    /// buffer is empty and none of them can be looked up any more.
    fn assert_drains(buffer: &mut ReorderBuffer, seq_nrs: &[u16]) {
        for &seq_nr in seq_nrs {
            let packet = buffer.remove(seq_nr).unwrap();
            assert_eq!(packet.header.seq_nr, seq_nr);
        }
        assert_eq!(buffer.len(), 0);
        for &seq_nr in seq_nrs {
            assert!(buffer.get(seq_nr).is_none());
        }
    }

    /// Sequence numbers in iteration order.
    fn iter_seq_nrs(buffer: &ReorderBuffer) -> Vec<u16> {
        buffer.iter().map(|packet| packet.header.seq_nr).collect()
    }

    #[test]
    fn insertion_orderd() {
        let seq_nrs = [1, 2, 3];
        let buffer = buffer_with(256, &seq_nrs);
        assert_contains_all(&buffer, &seq_nrs);
        assert_eq!(buffer.len(), 3);
    }

    #[test]
    fn insertion_unorderd() {
        let seq_nrs = [3, 1, 4];
        let buffer = buffer_with(256, &seq_nrs);
        assert_contains_all(&buffer, &seq_nrs);
        assert_eq!(buffer.len(), 3);
    }

    #[test]
    fn insertion_unorderd_large_gap() {
        let seq_nrs = [253, 747, 108];
        let buffer = buffer_with(256, &seq_nrs);
        assert_contains_all(&buffer, &seq_nrs);
        assert_eq!(buffer.len(), 3);
    }

    #[test]
    fn insertion_orderd_large_gap() {
        let seq_nrs = [245, 922];
        let buffer = buffer_with(256, &seq_nrs);
        assert_contains_all(&buffer, &seq_nrs);
        assert_eq!(buffer.len(), 2);
    }

    #[test]
    fn index_collision() {
        // Tests the case where the seq_nr modulo the capacity yields an
        // existing entry which doesn't match the one being inserted.
        // Caught by fuzzing.
        let seq_nrs = [2570, 2698];
        let buffer = buffer_with(64, &seq_nrs);
        assert_contains_all(&buffer, &seq_nrs);
        assert_eq!(buffer.len(), 2);
    }

    #[test]
    fn resizing() {
        // Ensures the existing span plus the new value fits in the resized
        // buffer. Caught by fuzzing.
        let seq_nrs = [25413, 25392, 16744, 2607];
        let buffer = buffer_with(64, &seq_nrs);
        assert_contains_all(&buffer, &seq_nrs);
        assert_eq!(buffer.len(), 4);
    }

    #[test]
    fn removal() {
        let seq_nrs = [3, 6, 7];
        let mut buffer = buffer_with(64, &seq_nrs);

        assert!(buffer.get(8).is_none());
        assert!(buffer.get(5).is_none());
        assert_eq!(buffer.len(), 3);

        assert_drains(&mut buffer, &seq_nrs);
    }

    #[test]
    fn removal_of_last_with_wraparound() {
        // Ensures calculating the new last index works as expected when
        // wrapping around. Found by fuzzing.
        let seq_nrs = [57078, 56842];
        let mut buffer = buffer_with(64, &seq_nrs);
        assert_eq!(buffer.len(), 2);
        assert_drains(&mut buffer, &seq_nrs);
    }

    #[test]
    fn removal_of_last_with_wraparound_v2() {
        // Ensures calculating the new last index works as expected when
        // wrapping around. Found by fuzzing.
        let seq_nrs = [22320, 22370, 14126];
        let mut buffer = buffer_with(64, &seq_nrs);
        assert_eq!(buffer.len(), 3);
        assert_drains(&mut buffer, &seq_nrs);
    }

    #[test]
    fn duplicate_insert_is_ignored() {
        let mut buffer = ReorderBuffer::new(64);
        buffer.insert(make_packet_with_data(5, Bytes::from_static(b"abc")));
        buffer.insert(make_packet_with_data(5, Bytes::from_static(b"defgh")));
        assert_eq!(buffer.len(), 1);
        assert_eq!(buffer.size(), 3);
        assert_eq!(buffer.get(5).unwrap().data, Bytes::from_static(b"abc"));
    }

    #[test]
    fn size_tracks_payload_bytes() {
        let mut buffer = ReorderBuffer::new(64);
        buffer.insert(make_packet_with_data(1, Bytes::from_static(b"ab")));
        buffer.insert(make_packet_with_data(3, Bytes::from_static(b"cde")));
        assert_eq!(buffer.size(), 5);
        buffer.remove(1).unwrap();
        assert_eq!(buffer.size(), 3);
        buffer.remove(3).unwrap();
        assert_eq!(buffer.size(), 0);
        assert!(buffer.is_empty());
    }

    #[test]
    fn iter_yields_ascending_seq_nrs_across_wrap_and_resize() {
        // 5 lands before 10 by wrapping to the end of the storage; 14 then
        // forces a resize.
        let mut buffer = buffer_with(8, &[10, 5]);
        assert_eq!(iter_seq_nrs(&buffer), vec![5, 10]);
        buffer.insert(make_packet(14));
        assert_eq!(iter_seq_nrs(&buffer), vec![5, 10, 14]);
        assert_eq!(buffer.len(), 3);
    }

    #[test]
    fn remove_of_aliased_slot_returns_none() {
        // With capacity 4, seq_nr 5 maps onto the slot holding seq_nr 1.
        let mut buffer = buffer_with(4, &[1]);
        assert!(buffer.remove(5).is_none());
        assert!(buffer.remove(2).is_none());
        assert_eq!(buffer.get(1).unwrap().header.seq_nr, 1);
        assert_eq!(buffer.len(), 1);
    }
}