s2n_quic_core/
counter.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::number::{
5    CheckedAddAssign, CheckedMulAssign, CheckedSubAssign, SaturatingAddAssign, SaturatingMulAssign,
6    SaturatingSubAssign, UpcastFrom,
7};
8use core::{cmp::Ordering, marker::PhantomData, ops};
9
/// A counter that guards against silently wrapping around on overflow.
///
/// A wrapped counter would let one counting mistake corrupt every value derived from it, so
/// instead the counter keeps such errors contained. What happens on overflow depends on the
/// build:
///
/// * If `debug_assertions` are enabled, the counter will panic on overflow
/// * If the `checked-counters` feature flag is defined, the counter will panic on overflow, even
///   in release builds.
/// * Otherwise, the counter will saturate
///
/// Supplying the `Saturating` behavior parameter makes the counter saturate in every build:
///
/// ```rust
/// use s2n_quic_core::counter::{Counter, Saturating};
///
/// let counter: Counter<u32, Saturating> = Default::default();
/// ```
#[derive(Debug, Default, Clone, Copy, Hash)]
pub struct Counter<T, Behavior = ()>(T, PhantomData<Behavior>);
29
/// Behavior marker that makes a `Counter` always saturate, regardless of `debug_assertions` or
/// the `checked-counters` feature.
#[derive(Debug, Default, Clone, Copy, Hash)]
pub struct Saturating;
33
34impl<T, Behavior> Counter<T, Behavior> {
35    /// Creates a new counter with an initial value
36    #[inline]
37    pub const fn new(value: T) -> Self {
38        Self(value, PhantomData)
39    }
40
41    #[inline]
42    pub fn set(&mut self, value: T) {
43        self.0 = value;
44    }
45
46    /// Tries to convert V to T and add it to the current counter value
47    #[inline]
48    pub fn try_add<V>(&mut self, value: V) -> Result<(), T::Error>
49    where
50        T: TryFrom<V>,
51        Self: ops::AddAssign<T>,
52    {
53        let value = T::try_from(value)?;
54        *self += value;
55        Ok(())
56    }
57
58    /// Tries to convert V to T and subtract it from the current counter value
59    #[inline]
60    pub fn try_sub<V>(&mut self, value: V) -> Result<(), T::Error>
61    where
62        T: TryFrom<V>,
63        Self: ops::SubAssign<T>,
64    {
65        let value = T::try_from(value)?;
66        *self -= value;
67        Ok(())
68    }
69}
70
/// Generates the `$op` compound-assignment implementations for `Counter`.
///
/// Two impls are produced:
///
/// * `Counter<T, ()>`: panics with `"counter overflow"` when the `checked-counters` feature or
///   `debug_assertions` is enabled, and saturates otherwise.
/// * `Counter<T, Saturating>`: always saturates.
///
/// Arguments:
///
/// * `$op` / `$method`: the `core::ops` trait and its method (e.g. `AddAssign` / `add_assign`)
/// * `$saturating_trait` / `$saturating_method`: the saturating counterpart
/// * `$checked_trait` / `$checked`: the checked counterpart, whose result is `.expect`-ed
macro_rules! assign_trait {
    (
        $op:ident,
        $method:ident,
        $saturating_trait:ident,
        $saturating_method:ident,
        $checked_trait:ident,
        $checked:ident
    ) => {
        impl<T, R> ops::$op<R> for Counter<T, ()>
        where
            // `ops::$op` and `UpcastFrom<R>` are not needed by the body any more. They are kept
            // so the set of types this impl applies to, and inference at call sites, is unchanged.
            T: $saturating_trait<R> + $checked_trait<R> + ops::$op + UpcastFrom<R>,
        {
            #[inline]
            fn $method(&mut self, rhs: R) {
                if cfg!(any(feature = "checked-counters", debug_assertions)) {
                    // Use the explicit checked operation instead of the native operator. The
                    // native operator only traps when `overflow-checks` is enabled, and that
                    // setting can be disabled independently of `debug_assertions`. In that
                    // configuration the counter would silently wrap.
                    (self.0).$checked(rhs).expect("counter overflow");
                } else {
                    (self.0).$saturating_method(rhs);
                }
            }
        }

        impl<T, R> ops::$op<R> for Counter<T, Saturating>
        where
            T: $saturating_trait<R>,
        {
            #[inline]
            fn $method(&mut self, rhs: R) {
                (self.0).$saturating_method(rhs);
            }
        }
    };
}
108
109assign_trait!(
110    AddAssign,
111    add_assign,
112    SaturatingAddAssign,
113    saturating_add_assign,
114    CheckedAddAssign,
115    checked_add_assign
116);
117
118assign_trait!(
119    SubAssign,
120    sub_assign,
121    SaturatingSubAssign,
122    saturating_sub_assign,
123    CheckedSubAssign,
124    checked_sub_assign
125);
126
127assign_trait!(
128    MulAssign,
129    mul_assign,
130    SaturatingMulAssign,
131    saturating_mul_assign,
132    CheckedMulAssign,
133    checked_mul_assign
134);
135
136impl<T, B> UpcastFrom<Counter<T, B>> for T {
137    #[inline]
138    fn upcast_from(value: Counter<T, B>) -> Self {
139        value.0
140    }
141}
142
143impl<T, B> UpcastFrom<&Counter<T, B>> for T
144where
145    T: for<'a> UpcastFrom<&'a T>,
146{
147    #[inline]
148    fn upcast_from(value: &Counter<T, B>) -> Self {
149        T::upcast_from(&value.0)
150    }
151}
152
153impl<T, B> ops::Deref for Counter<T, B> {
154    type Target = T;
155
156    #[inline]
157    fn deref(&self) -> &Self::Target {
158        &self.0
159    }
160}
161
162impl<T, B, R> PartialEq<R> for Counter<T, B>
163where
164    Self: PartialOrd<R>,
165{
166    #[inline]
167    fn eq(&self, other: &R) -> bool {
168        self.partial_cmp(other) == Some(Ordering::Equal)
169    }
170}
171
172impl<T, B> PartialOrd<T> for Counter<T, B>
173where
174    T: PartialOrd<T>,
175{
176    #[inline]
177    fn partial_cmp(&self, other: &T) -> Option<Ordering> {
178        self.0.partial_cmp(other)
179    }
180}
181
182impl<T, B> PartialOrd for Counter<T, B>
183where
184    T: PartialOrd<T>,
185{
186    #[inline]
187    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
188        self.0.partial_cmp(&other.0)
189    }
190}
191
192impl<T, B> Eq for Counter<T, B> where Self: Ord {}
193
194impl<T, B> Ord for Counter<T, B>
195where
196    T: Ord,
197{
198    #[inline]
199    fn cmp(&self, other: &Self) -> Ordering {
200        self.0.cmp(&other.0)
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Narrower unsigned operands are widened to the counter's type automatically.
    #[test]
    fn automatic_upcast() {
        let mut total: Counter<u32> = Counter::new(0);
        total += 1u8;
        total += 2u16;
        total += 3u32;

        // The result compares equal to both a bare integer and another counter.
        assert_eq!(total, 6u32);
        assert_eq!(total, Counter::new(6));
    }

    /// A `Saturating` counter stops at `u8::MAX` instead of overflowing.
    #[test]
    fn saturating() {
        let mut clamped: Counter<u8, Saturating> = Counter::new(0);
        clamped += 250;
        clamped += 250;
        clamped += 123;

        assert_eq!(clamped, Counter::new(u8::MAX));
    }
}