// s2n_quic_core/memo.rs

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

4use core::{cell::Cell, fmt};
5
6/// A datastructure that [memoizes](https://wikipedia.org/wiki/Memoization) a query function
7///
8/// This can be used for when queries rarely change and can potentially be expensive or on hot
9/// code paths. After the `input` is mutated, the query value should be `clear`ed to signal that
10/// the function needs to be executed again.
11///
12/// In debug mode the `get` call will always run the query and assert that the values match.
13#[derive(Clone)]
14pub struct Memo<T: Copy, Input, Check = DefaultConsistencyCheck> {
15    value: Cell<Option<T>>,
16    query: fn(&Input) -> T,
17    check: Check,
18}
19
20impl<T: Copy + fmt::Debug, Input, Check> fmt::Debug for Memo<T, Input, Check> {
21    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22        f.debug_tuple("Memo").field(&self.value.get()).finish()
23    }
24}
25
26impl<T: Copy + PartialEq + fmt::Debug, Input, Check: ConsistencyCheck> Memo<T, Input, Check> {
27    /// Creates a new `Memo` over a query function
28    #[inline]
29    pub fn new(query: fn(&Input) -> T) -> Self {
30        Self {
31            value: Cell::new(None),
32            query,
33            check: Check::default(),
34        }
35    }
36
37    /// Returns the current value of the query function, which may be cached
38    #[inline]
39    #[track_caller]
40    pub fn get(&self, input: &Input) -> T {
41        if let Some(value) = self.value.get() {
42            // make sure the values match
43            self.check.check_consistency(value, input, self.query);
44            return value;
45        }
46
47        let value = (self.query)(input);
48        self.value.set(Some(value));
49        value
50    }
51
52    /// Clears the cached value of the query function
53    #[inline]
54    pub fn clear(&self) {
55        self.value.set(None);
56    }
57
58    /// Asserts that the cached value reflects the current query result in debug mode
59    #[inline]
60    #[track_caller]
61    pub fn check_consistency(&self, input: &Input) {
62        if cfg!(debug_assertions) {
63            // `get` will assert the value matches the query internally
64            let _ = self.get(input);
65        }
66    }
67}
68
/// Strategy that decides how a `Memo` validates its cached value
pub trait ConsistencyCheck: Clone + Default {
    /// Invoked when the `Memo` already holds a cached value
    ///
    /// Implementations may recompute `query(input)` and assert that it equals `cache`, or may
    /// skip validation entirely.
    fn check_consistency<T, Input>(&self, cache: T, input: &Input, query: fn(&Input) -> T)
    where
        T: PartialEq + fmt::Debug;
}
81
82#[derive(Copy, Clone, Default)]
83pub struct ConsistencyCheckAlways;
84
85impl ConsistencyCheck for ConsistencyCheckAlways {
86    #[inline]
87    fn check_consistency<T: PartialEq + fmt::Debug, Input>(
88        &self,
89        actual: T,
90        input: &Input,
91        query: fn(&Input) -> T,
92    ) {
93        let expected = query(input);
94        assert_eq!(expected, actual);
95    }
96}
97
98#[derive(Copy, Clone, Default)]
99pub struct ConsistencyCheckNever;
100
101impl ConsistencyCheck for ConsistencyCheckNever {
102    #[inline]
103    fn check_consistency<T: PartialEq + fmt::Debug, Input>(
104        &self,
105        _cache: T,
106        _input: &Input,
107        _query: fn(&Input) -> T,
108    ) {
109        // noop
110    }
111}
112
113#[cfg(debug_assertions)]
114pub type DefaultConsistencyCheck = ConsistencyCheckAlways;
115#[cfg(not(debug_assertions))]
116pub type DefaultConsistencyCheck = ConsistencyCheckNever;
117
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Default)]
    struct Input<Value> {
        value: Value,
        should_query: bool,
    }

    /// The query only runs while the cache is empty, and `clear` forces it to run again
    #[test]
    fn memo_test() {
        let memo = Memo::<u64, Input<_>, ConsistencyCheckNever>::new(|input| {
            assert!(
                input.should_query,
                "query was called when it wasn't expected"
            );
            input.value
        });

        assert_eq!(
            memo.get(&Input {
                value: 1,
                should_query: true,
            }),
            1
        );

        assert_eq!(
            memo.get(&Input {
                value: 2,
                should_query: false,
            }),
            1
        );

        memo.clear();

        assert_eq!(
            memo.get(&Input {
                value: 3,
                should_query: true,
            }),
            3
        );

        assert_eq!(
            memo.get(&Input {
                value: 4,
                should_query: false,
            }),
            3
        );
    }

    /// With `ConsistencyCheckAlways`, reading a stale cached value must panic
    #[test]
    #[should_panic]
    fn always_check_detects_stale_cache() {
        let memo = Memo::<u64, u64, ConsistencyCheckAlways>::new(|input| *input);
        assert_eq!(memo.get(&1), 1);
        // the input changed without a `clear`, so the cached `1` no longer matches the query
        let _ = memo.get(&2);
    }

    /// With `ConsistencyCheckAlways`, a memo that is cleared whenever its input changes never
    /// panics
    #[test]
    fn always_check_accepts_fresh_cache() {
        let memo = Memo::<u64, u64, ConsistencyCheckAlways>::new(|input| *input);
        assert_eq!(memo.get(&1), 1);
        assert_eq!(memo.get(&1), 1);
        memo.check_consistency(&1);

        memo.clear();
        assert_eq!(memo.get(&2), 2);
        memo.check_consistency(&2);
    }
}