s2n_quic_core/packet/number/
map.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::packet::number::{PacketNumber, PacketNumberRange, PacketNumberSpace};
5use alloc::{boxed::Box, vec::Vec};
6use core::fmt;
7
8/// A data structure for tracking packets that are pending acknowledgement
9///
10/// The following assumptions are made and exploited
11///
12/// * Packet numbers are monotonically generated and inserted
13/// * Packet numbers will mostly be removed in ranges
14/// * Packet numbers that are deemed lost will also be removed and retransmitted
15///
16/// This is implemented as a buffer ring with a moving range for the lower and upper bound of
17/// contained packet numbers. The following example illustrates how each field tracks state:
18///
19/// ```ignore
20/// packets = [ PN(2), None, PN(0), PN(1) ]
21///                           ^ index = 2
22/// start = PN(0)
23/// end = PN(2)
24/// ```
25///
26/// Upon inserting `PN(3)` the state is now:
27///
28/// ```ignore
29/// packets = [ PN(2), PN(3), PN(0), PN(1) ]
30///                           ^ index = 2
31/// start = PN(0)
32/// end = PN(3)
33/// ```
34///
35/// Upon removing `PN(0)` the state is now:
36///
37/// ```ignore
38/// packets = [ PN(2), PN(3), None, PN(1) ]
39///                                 ^ index = 3
40/// start = PN(1)
41/// end = PN(3)
42/// ```
43#[derive(Clone)]
44pub struct Map<V> {
45    /// The stored values for each packet number
46    values: Box<[Option<V>]>,
47    /// The smallest contained inclusive packet number in the map
48    start: PacketNumber,
49    /// The largest contained inclusive packet number in the map
50    end: PacketNumber,
51    /// The starting index of the first occupied packet
52    ///
53    /// This field will be set to the `packets.len()` if the map is empty
54    index: usize,
55}
56
57impl<V: fmt::Debug> fmt::Debug for Map<V> {
58    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59        f.debug_map().entries(self.iter()).finish()
60    }
61}
62
63/// Start with 8 sent packets at a time
64///
65/// Capacity will grow exponentially as more packet number entries are added
66const DEFAULT_CAPACITY: usize = 8;
67
68impl<V> Default for Map<V> {
69    fn default() -> Self {
70        // we use the Initial packet number space as a filler until an actual
71        // packet number is inserted
72        let base = PacketNumberSpace::Initial.new_packet_number(0u8.into());
73
74        let mut values = Vec::with_capacity(DEFAULT_CAPACITY);
75        while values.len() != values.capacity() {
76            values.push(None);
77        }
78        let values = values.into_boxed_slice();
79
80        // Set the index to the len (OOB) to indicate that it's empty
81        let index = values.len();
82
83        Self {
84            values,
85            start: base,
86            end: base,
87            index,
88        }
89    }
90}
91
92impl<V> Map<V> {
93    /// Inserts the given `value`
94    pub fn insert(&mut self, packet_number: PacketNumber, value: V) {
95        if self.is_empty() {
96            self.start = packet_number;
97            self.end = packet_number;
98            self.values[0] = Some(value);
99            self.index = 0;
100            return;
101        }
102
103        // The implementation assumes monotonicity of insertion
104        debug_assert!(
105            packet_number > self.start && packet_number > self.end,
106            "packet numbers should be monotonic: {:?} > {:?} && {:?}",
107            packet_number,
108            self.start,
109            self.end
110        );
111
112        // check if we need to increase capacity
113        let distance = (packet_number.as_u64() - self.start.as_u64()) as usize;
114
115        let index = if distance >= self.values.len() {
116            self.resize(distance);
117
118            // use the distance as the index since we've already resized beyond it
119            distance
120        } else {
121            (self.index + distance) % self.values.len()
122        };
123
124        self.values[index] = Some(value);
125        self.end = packet_number;
126    }
127
128    /// Inserts the given `value` into the map or updates the existing entry
129    pub fn insert_or_update<F: FnOnce(&mut V)>(
130        &mut self,
131        packet_number: PacketNumber,
132        value: V,
133        update: F,
134    ) {
135        if self.is_empty() {
136            self.start = packet_number;
137            self.end = packet_number;
138            self.values[0] = Some(value);
139            self.index = 0;
140            return;
141        }
142
143        // The implementation assumes insertion is not lower than the start
144        debug_assert!(
145            packet_number >= self.start,
146            "packet numbers should be monotonic: {:?} > {:?}",
147            packet_number,
148            self.start,
149        );
150
151        // check if we need to increase capacity
152        let distance = (packet_number.as_u64() - self.start.as_u64()) as usize;
153
154        let index = if distance >= self.values.len() {
155            self.resize(distance);
156
157            // use the distance as the index since we've already resized beyond it
158            distance
159        } else {
160            (self.index + distance) % self.values.len()
161        };
162
163        let entry = &mut self.values[index];
164
165        if let Some(prev) = entry.as_mut() {
166            update(prev);
167        } else {
168            *entry = Some(value);
169        }
170
171        self.end = self.end.max(packet_number);
172    }
173
174    /// Returns a reference to the `V` associated with the given `packet_number`
175    #[inline]
176    pub fn get(&self, packet_number: PacketNumber) -> Option<&V> {
177        let index = self.pn_index(packet_number)?;
178        self.values[index].as_ref()
179    }
180
181    /// Removes the value associated with the given `packet_number`
182    /// and returns the value if it was present
183    pub fn remove(&mut self, packet_number: PacketNumber) -> Option<V> {
184        let index = self.pn_index(packet_number)?;
185        let info = self.values[index].take()?;
186
187        // update the bounds
188        match (self.start == packet_number, self.end == packet_number) {
189            // the bounds are inclusive so the map is now empty, reset it
190            //              [_, _, _, 3]
191            // remove(3) => [_, _, _, _]
192            (true, true) => {
193                self.clear();
194            }
195            // the packet was removed from the front
196            //              [0, 1, _, 3, 4]
197            // remove(0) => [_, 1, _, 3, 4]
198            // remove(1) => [_, _, _, 3, 4]
199            // remove(3) => [_, _, _, _, 4]
200            (true, false) => {
201                self.set_start(packet_number.next().unwrap());
202            }
203            // the packet was removed from the back
204            //              [0, 1, _, 3, 4]
205            // remove(4) => [0, 1, _, 3, _]
206            // remove(3) => [0, 1, _, _, _]
207            // remove(1) => [0, _, _, _, _]
208            (false, true) => {
209                self.set_end(packet_number.prev().unwrap());
210            }
211            // the packet was removed from the middle
212            //              [0, 1, 2]
213            // remove(2) => [0, _, 2]
214            (false, false) => {
215                // nothing to do
216            }
217        }
218
219        Some(info)
220    }
221
222    /// Removes a range of packets from the map and returns their value
223    #[inline]
224    pub fn remove_range(&mut self, range: PacketNumberRange) -> RemoveIter<'_, V> {
225        RemoveIter::new(self, range)
226    }
227
228    /// Get the inclusive PacketNumberRange
229    #[inline]
230    pub fn get_range(&self) -> PacketNumberRange {
231        PacketNumberRange::new(self.start, self.end)
232    }
233
234    /// Gets an iterator over the sent packet entries, sorted by PacketNumber
235    #[inline]
236    pub fn iter(&self) -> Iter<'_, V> {
237        Iter::new(self)
238    }
239
240    /// Returns true if there are no entries
241    #[inline]
242    pub fn is_empty(&self) -> bool {
243        self.index == self.values.len()
244    }
245
246    /// Clears all of the packet information in the sent
247    #[inline]
248    pub fn clear(&mut self) {
249        self.index = self.values.len();
250    }
251
252    #[inline]
253    fn pn_index(&self, packet_number: PacketNumber) -> Option<usize> {
254        // the map is empty so there are no valid entries
255        if self.is_empty() {
256            return None;
257        }
258
259        // make sure it's within the inserted packet numbers
260        if packet_number > self.end {
261            return None;
262        }
263
264        let offset = packet_number.checked_distance(self.start)?;
265        let index = self.index.checked_add(offset as usize)?;
266        let index = index % self.values.len();
267        Some(index)
268    }
269
270    #[inline]
271    fn set_start(&mut self, packet_number: PacketNumber) {
272        // this function assumes we have at least one element
273        debug_assert!(!self.is_empty());
274        debug_assert!(packet_number >= self.start);
275        debug_assert!(packet_number <= self.end);
276
277        // find the next occupied slot
278        for packet_number in PacketNumberRange::new(packet_number, self.end) {
279            if self.get(packet_number).is_some() {
280                let index = self
281                    .pn_index(packet_number)
282                    .expect("packet should be in bounds");
283
284                self.index = index;
285                self.start = packet_number;
286                debug_assert!(self.start <= self.end);
287                debug_assert_eq!(self.pn_index(packet_number), Some(index));
288                return;
289            }
290        }
291
292        unreachable!("could not find an occupied entry; map should be empty");
293    }
294
295    #[inline]
296    fn set_end(&mut self, packet_number: PacketNumber) {
297        // this function assumes we have at least one element
298        debug_assert!(!self.is_empty());
299        debug_assert!(packet_number >= self.start);
300        debug_assert!(packet_number <= self.end);
301
302        // find the next occupied slot
303        for packet_number in PacketNumberRange::new(self.start, packet_number).rev() {
304            if self.get(packet_number).is_some() {
305                self.end = packet_number;
306                debug_assert!(self.start <= self.end);
307                return;
308            }
309        }
310
311        unreachable!("could not find an occupied entry; map should be empty");
312    }
313
314    fn resize(&mut self, len: usize) {
315        let mut new_len = self.values.len();
316
317        // grow capacity until we can fit the inserted PN
318        loop {
319            new_len *= 2;
320            if len < new_len {
321                break;
322            }
323        }
324
325        // allocate a new packet buffer and copy the previous values
326        let mut values = Vec::with_capacity(new_len);
327        // The packets are stored in a ring so we copy from the index
328        // to the end, then from the start to the index
329        values.extend(self.values[self.index..].iter_mut().map(|v| v.take()));
330        values.extend(self.values[..self.index].iter_mut().map(|v| v.take()));
331        while values.len() != values.capacity() {
332            values.push(None);
333        }
334
335        // reset the index to the beginning of the buffer
336        self.index = 0;
337        self.values = values.into_boxed_slice();
338    }
339}
340
341/// An iterator over all of the contained packet numbers
342///
343/// This iterator is optimized to reduce the amount of bounds checks being performed
344#[derive(Debug)]
345pub struct Iter<'a, V> {
346    packets: &'a Map<V>,
347    packet_number: Option<PacketNumber>,
348    index: usize,
349    remaining: usize,
350}
351
352impl<'a, V> Iter<'a, V> {
353    #[inline]
354    fn new(packets: &'a Map<V>) -> Self {
355        let start = packets.start;
356        let end = packets.end;
357        let index = packets.index;
358
359        let mut iter = Self {
360            packets,
361            packet_number: Some(start),
362            index,
363            // start with an empty iterator
364            remaining: 0,
365        };
366
367        // make sure we have at least one packet
368        if iter.packets.is_empty() {
369            return iter;
370        }
371
372        // set the number of remaining entries based on the bounded range
373        iter.remaining = (end.as_u64() - start.as_u64()) as usize;
374        // we always have at least 1 items since the range is inclusive
375        iter.remaining += 1;
376
377        debug_assert!(iter.remaining <= iter.packets.values.len());
378
379        iter
380    }
381}
382
383impl<'a, V> Iterator for Iter<'a, V> {
384    type Item = (PacketNumber, &'a V);
385
386    #[inline]
387    fn next(&mut self) -> Option<Self::Item> {
388        while self.remaining > 0 {
389            self.remaining -= 1;
390
391            let packet_number = self.packet_number?;
392            self.packet_number = packet_number.next();
393
394            let index = self.index;
395            self.index = (index + 1) % self.packets.values.len();
396
397            if let Some(info) = self.packets.values[index].as_ref() {
398                return Some((packet_number, info));
399            }
400        }
401
402        None
403    }
404}
405
406/// An iterator which removes a set of packet numbers in a range
407///
408/// This iterator is optimized to reduce the amount of bounds checks being performed
409#[derive(Debug)]
410pub struct RemoveIter<'a, V> {
411    packets: &'a mut Map<V>,
412    packet_number: Option<PacketNumber>,
413    index: usize,
414    remaining: usize,
415}
416
417impl<'a, V> RemoveIter<'a, V> {
418    #[inline]
419    fn new(packets: &'a mut Map<V>, range: PacketNumberRange) -> Self {
420        let mut start = packets.start;
421        let mut end = packets.end;
422
423        let index = packets.index;
424
425        let mut iter = Self {
426            packets,
427            packet_number: None,
428            index,
429            // start with an empty iterator
430            remaining: 0,
431        };
432
433        // make sure we have at least one packet
434        if iter.packets.is_empty() {
435            return iter;
436        }
437
438        // ensure the range overlaps with the contained items
439        if range.end() < start || range.start() > end {
440            return iter;
441        }
442
443        use core::cmp::Ordering::*;
444
445        match (range.start().cmp(&start), range.end().cmp(&end)) {
446            (Less, Equal) | (Less, Greater) | (Equal, Greater) | (Equal, Equal) => {
447                // deleting all entries
448
449                // clear the sent packets
450                //
451                // NOTE: this doesn't actually delete anything in the buffer
452                iter.packets.clear();
453
454                // no need to update index as it's already set to the lower bound
455            }
456            (Less, Less) | (Equal, Less) => {
457                // deleting start
458                end = range.end();
459
460                iter.packets.set_start(end.next().unwrap());
461            }
462            (Greater, Greater) | (Greater, Equal) => {
463                // deleting end
464                start = range.start();
465
466                iter.index = iter
467                    .packets
468                    .pn_index(start)
469                    .expect("packet number bounds have already been checked");
470
471                iter.packets.set_end(start.prev().unwrap());
472            }
473            (Greater, Less) => {
474                // deleting middle part
475                start = range.start();
476                end = range.end();
477
478                iter.index = iter
479                    .packets
480                    .pn_index(start)
481                    .expect("packet number bounds have already been checked");
482            }
483        }
484
485        // Update the starting packet number
486        iter.packet_number = Some(start);
487        // set the number of remaining entries based on the bounded range
488        iter.remaining = (end.as_u64() - start.as_u64()) as usize;
489        // we always have at least 1 items since the range is inclusive
490        iter.remaining += 1;
491
492        debug_assert!(iter.remaining <= iter.packets.values.len());
493
494        iter
495    }
496}
497
498impl<V> Iterator for RemoveIter<'_, V> {
499    type Item = (PacketNumber, V);
500
501    #[inline]
502    fn next(&mut self) -> Option<Self::Item> {
503        while self.remaining > 0 {
504            self.remaining -= 1;
505
506            let packet_number = self.packet_number?;
507            self.packet_number = packet_number.next();
508
509            let index = self.index;
510            self.index = (index + 1) % self.packets.values.len();
511
512            if let Some(info) = self.packets.values[index].take() {
513                return Some((packet_number, info));
514            }
515        }
516
517        None
518    }
519}
520
521impl<V> Drop for RemoveIter<'_, V> {
522    fn drop(&mut self) {
523        // make sure the iterator is drained, otherwise the entries might dangle
524        while self.next().is_some() {}
525    }
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        packet::number::{PacketNumber, PacketNumberRange, PacketNumberSpace},
        varint::VarInt,
    };
    use alloc::collections::BTreeMap;
    use bolero::{check, generator::*};

    /// The concrete map type exercised by these tests
    type TestMap = Map<u64>;

    /// Inserted values can be read back with `get`, a not-yet-inserted packet number returns
    /// `None`, and every entry yielded by `iter` agrees with `get`
    #[test]
    fn insert_get_range() {
        let mut sent_packets = TestMap::default();

        let packet_number_1 = PacketNumberSpace::Initial.new_packet_number(VarInt::from_u8(1));
        let packet_number_2 = packet_number_1.next().unwrap();
        let packet_number_3 = packet_number_2.next().unwrap();

        sent_packets.insert(packet_number_1, 1);
        sent_packets.insert(packet_number_2, 2);

        assert!(sent_packets.get(packet_number_1).is_some());
        assert!(sent_packets.get(packet_number_2).is_some());
        assert!(sent_packets.get(packet_number_3).is_none());

        assert_eq!(sent_packets.get(packet_number_1).unwrap(), &1);
        assert_eq!(sent_packets.get(packet_number_2).unwrap(), &2);

        sent_packets.insert(packet_number_3, 3);

        assert!(sent_packets.get(packet_number_3).is_some());
        assert_eq!(sent_packets.get(packet_number_3).unwrap(), &3);

        for (packet_number, sent_packet_info) in sent_packets.iter() {
            assert_eq!(sent_packets.get(packet_number).unwrap(), sent_packet_info);
        }
    }

    /// Removing an entry returns its value and makes it unreachable through `get`
    #[test]
    fn remove() {
        let mut sent_packets = TestMap::default();
        let packet_number = PacketNumberSpace::Initial.new_packet_number(VarInt::from_u8(1));
        sent_packets.insert(packet_number, 1);

        assert!(sent_packets.get(packet_number).is_some());
        assert_eq!(sent_packets.get(packet_number).unwrap(), &1);

        assert_eq!(Some(1), sent_packets.remove(packet_number));

        assert!(sent_packets.get(packet_number).is_none());

        // Removing a packet that was already removed doesn't panic
        assert_eq!(None, sent_packets.remove(packet_number));
    }

    /// A default map is empty until a value is inserted
    #[test]
    fn empty() {
        let mut sent_packets = TestMap::default();
        assert!(sent_packets.is_empty());

        let packet_number = PacketNumberSpace::Initial.new_packet_number(VarInt::from_u8(1));
        sent_packets.insert(packet_number, 1);
        assert!(!sent_packets.is_empty());
    }

    /// Inserting a packet number from another space into a non-empty map is expected to panic
    ///
    /// NOTE(review): the panic presumably originates in `PacketNumber`'s cross-space comparison,
    /// which is not visible in this file
    #[test]
    #[should_panic]
    fn wrong_packet_space_on_insert() {
        let mut sent_packets = new_sent_packets(PacketNumberSpace::Initial);

        let packet_number =
            PacketNumberSpace::ApplicationData.new_packet_number(VarInt::from_u8(1));
        sent_packets.insert(packet_number, 1);
    }

    /// Looking up a packet number from another space is expected to panic
    #[test]
    #[should_panic]
    fn wrong_packet_space_on_get() {
        let sent_packets = new_sent_packets(PacketNumberSpace::Initial);

        let packet_number =
            PacketNumberSpace::ApplicationData.new_packet_number(VarInt::from_u8(1));
        sent_packets.get(packet_number);
    }

    /// Removing a range from another space is expected to panic
    #[test]
    #[should_panic]
    fn wrong_packet_space_on_remove_range() {
        let mut sent_packets = new_sent_packets(PacketNumberSpace::Initial);

        let packet_number_start =
            PacketNumberSpace::ApplicationData.new_packet_number(VarInt::from_u8(1));
        let packet_number_end =
            PacketNumberSpace::ApplicationData.new_packet_number(VarInt::from_u8(2));
        sent_packets
            .remove_range(PacketNumberRange::new(
                packet_number_start,
                packet_number_end,
            ))
            .for_each(|_| ());
    }

    /// Removing a packet number from another space is expected to panic
    #[test]
    #[should_panic]
    fn wrong_packet_space_on_remove() {
        let mut sent_packets = new_sent_packets(PacketNumberSpace::Initial);

        let packet_number =
            PacketNumberSpace::ApplicationData.new_packet_number(VarInt::from_u8(1));
        sent_packets.remove(packet_number);
    }

    /// Builds a map holding a single entry (value `0`) for packet number 0 in `space`
    fn new_sent_packets(space: PacketNumberSpace) -> TestMap {
        let mut sent_packets = TestMap::default();
        let packet_number = space.new_packet_number(VarInt::from_u8(0));
        sent_packets.insert(packet_number, 0);
        sent_packets
    }

    /// An operation to be performed against a model
    #[derive(Clone, Copy, Debug, TypeGenerator)]
    enum Operation {
        // Inserts the current packet number
        Insert,
        // Skips the packet number
        Skip,
        // Removes a packet number
        Remove(VarInt),
        // Removes a range of packet numbers
        RemoveRange(VarInt, VarInt),
    }

    /// Applies `ops` to both a `Map` and a `BTreeMap` oracle, asserting after every operation
    /// that the two hold the same entries
    fn model(ops: &[Operation]) {
        // the next packet number that `Insert`/`Skip` will consume
        let mut current = PacketNumberSpace::ApplicationData.new_packet_number(VarInt::from_u8(0));

        /// Tracks the subject against an oracle to ensure differential equivalency
        #[derive(Debug, Default)]
        struct Model {
            subject: TestMap,
            oracle: BTreeMap<PacketNumber, u64>,
        }

        impl Model {
            /// Inserts `packet_number` (using its numeric value as the value) into both maps
            pub fn insert(&mut self, packet_number: PacketNumber) {
                let value = packet_number.as_u64();

                self.subject.insert(packet_number, value);
                self.oracle.insert(packet_number, value);
                self.check_consistency();
            }

            /// Removes `packet_number` from both maps and asserts they return the same value
            pub fn remove(&mut self, packet_number: PacketNumber) {
                assert_eq!(
                    self.subject.remove(packet_number),
                    self.oracle.remove(&packet_number)
                );
                self.check_consistency();
            }

            /// Removes `range` (clamped to the subject's bounds) from both maps and asserts the
            /// removed entries match
            pub fn remove_range(&mut self, range: PacketNumberRange) {
                // trim range so we're not slamming the BTreeMap
                let range = if self.subject.is_empty() {
                    PacketNumberRange::new(range.start(), range.start())
                } else {
                    let start = range.start().max(self.subject.start);
                    let end = range.end().min(self.subject.end);
                    if start > end {
                        PacketNumberRange::new(start, start)
                    } else {
                        PacketNumberRange::new(start, end)
                    }
                };

                let actual: Vec<_> = self.subject.remove_range(range).collect();
                let mut expected = vec![];

                for pn in range {
                    if let Some(value) = self.oracle.remove(&pn) {
                        expected.push((pn, value));
                    }
                }

                assert_eq!(expected, actual);

                self.check_consistency();
            }

            /// Walks both maps in lockstep and asserts identical (packet number, value) sequences
            fn check_consistency(&self) {
                let mut subject = self.subject.iter();
                let mut oracle = self.oracle.iter();
                loop {
                    match (subject.next(), oracle.next()) {
                        (Some(actual), Some((expected_pn, expected_info))) => {
                            assert_eq!((*expected_pn, expected_info), actual);
                        }
                        (None, None) => break,
                        (actual, expected) => {
                            panic!("expected: {expected:?}, actual: {actual:?}");
                        }
                    }
                }
            }
        }

        let mut model = Model::default();

        for op in ops.iter().copied() {
            match op {
                Operation::Insert => {
                    model.insert(current);
                    current = current.next().unwrap();
                }
                Operation::Skip => {
                    current = current.next().unwrap();
                }
                Operation::Remove(pn) => {
                    let pn = PacketNumberSpace::ApplicationData.new_packet_number(pn);

                    model.remove(pn);
                }
                Operation::RemoveRange(start, end) => {
                    // generated bounds may arrive in either order
                    let (start, end) = if start > end {
                        (end, start)
                    } else {
                        (start, end)
                    };
                    let start = PacketNumberSpace::ApplicationData.new_packet_number(start);
                    let end = PacketNumberSpace::ApplicationData.new_packet_number(end);
                    let range = PacketNumberRange::new(start, end);

                    model.remove_range(range);
                }
            }
        }
    }

    /// Fuzzes `model` with generated operation sequences
    #[test]
    fn differential_test() {
        check!()
            .with_type::<Vec<Operation>>()
            .for_each(|ops| model(ops))
    }

    /// For any generated packet number, inserting into a fresh map makes it retrievable and the
    /// map non-empty; also registered as a kani proof when built under `kani`
    #[test]
    #[cfg_attr(kani, kani::proof, kani::unwind(9), kani::solver(kissat))]
    fn insert_value() {
        // Confirm that a value is inserted
        check!().with_type().cloned().for_each(|pn| {
            let space = PacketNumberSpace::ApplicationData;
            let mut map = Map::default();
            assert!(map.is_empty());
            let pn = space.new_packet_number(pn);

            map.insert(pn, ());

            assert!(map.get(pn).is_some());
            assert!(!map.is_empty());
        });
    }
}