1use core::{cell::Cell, fmt};
5
6#[derive(Clone)]
14pub struct Memo<T: Copy, Input, Check = DefaultConsistencyCheck> {
15 value: Cell<Option<T>>,
16 query: fn(&Input) -> T,
17 check: Check,
18}
19
20impl<T: Copy + fmt::Debug, Input, Check> fmt::Debug for Memo<T, Input, Check> {
21 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22 f.debug_tuple("Memo").field(&self.value.get()).finish()
23 }
24}
25
26impl<T: Copy + PartialEq + fmt::Debug, Input, Check: ConsistencyCheck> Memo<T, Input, Check> {
27 #[inline]
29 pub fn new(query: fn(&Input) -> T) -> Self {
30 Self {
31 value: Cell::new(None),
32 query,
33 check: Check::default(),
34 }
35 }
36
37 #[inline]
39 #[track_caller]
40 pub fn get(&self, input: &Input) -> T {
41 if let Some(value) = self.value.get() {
42 self.check.check_consistency(value, input, self.query);
44 return value;
45 }
46
47 let value = (self.query)(input);
48 self.value.set(Some(value));
49 value
50 }
51
52 #[inline]
54 pub fn clear(&self) {
55 self.value.set(None);
56 }
57
58 #[inline]
60 #[track_caller]
61 pub fn check_consistency(&self, input: &Input) {
62 if cfg!(debug_assertions) {
63 let _ = self.get(input);
65 }
66 }
67}
68
/// Strategy for validating a memoized value against its query.
pub trait ConsistencyCheck: Clone + Default {
    /// Called by `Memo` with a cached value, the input supplied by the caller,
    /// and the memo's query function.
    ///
    /// Declared `#[track_caller]` so that every implementation reports panics
    /// at the caller's location. Panics then propagate through `Memo::get`'s own
    /// `#[track_caller]` to user code, instead of pointing inside the implementation.
    #[track_caller]
    fn check_consistency<T: PartialEq + fmt::Debug, Input>(
        &self,
        cached: T,
        input: &Input,
        query: fn(&Input) -> T,
    );
}
81
/// Consistency check that re-runs the query on every cache hit and asserts that
/// the result still equals the cached value.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct ConsistencyCheckAlways;
84
85impl ConsistencyCheck for ConsistencyCheckAlways {
86 #[inline]
87 fn check_consistency<T: PartialEq + fmt::Debug, Input>(
88 &self,
89 actual: T,
90 input: &Input,
91 query: fn(&Input) -> T,
92 ) {
93 let expected = query(input);
94 assert_eq!(expected, actual);
95 }
96}
97
/// Consistency check that accepts every cached value without re-running the query.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct ConsistencyCheckNever;
100
101impl ConsistencyCheck for ConsistencyCheckNever {
102 #[inline]
103 fn check_consistency<T: PartialEq + fmt::Debug, Input>(
104 &self,
105 _cache: T,
106 _input: &Input,
107 _query: fn(&Input) -> T,
108 ) {
109 }
111}
112
113#[cfg(debug_assertions)]
114pub type DefaultConsistencyCheck = ConsistencyCheckAlways;
115#[cfg(not(debug_assertions))]
116pub type DefaultConsistencyCheck = ConsistencyCheckNever;
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Test input: `value` is what the query returns, and `should_query` says
    /// whether the query is allowed to run for this input.
    #[derive(Debug, Default)]
    struct Input<Value> {
        value: Value,
        should_query: bool,
    }

    /// Shorthand constructor for the `u64` inputs used below.
    fn make_input(value: u64, should_query: bool) -> Input<u64> {
        Input {
            value,
            should_query,
        }
    }

    #[test]
    fn memo_test() {
        let memo = Memo::<u64, Input<_>, ConsistencyCheckNever>::new(|input| {
            assert!(
                input.should_query,
                "query was called when it wasn't expected"
            );
            input.value
        });

        // The first lookup computes; the second must be served from the cache.
        assert_eq!(memo.get(&make_input(1, true)), 1);
        assert_eq!(memo.get(&make_input(2, false)), 1);

        memo.clear();

        // After clearing, the query runs again and its new result is cached.
        assert_eq!(memo.get(&make_input(3, true)), 3);
        assert_eq!(memo.get(&make_input(4, false)), 3);
    }

    #[test]
    fn always_check_accepts_consistent_value() {
        let memo = Memo::<u64, Input<u64>, ConsistencyCheckAlways>::new(|input| input.value);
        let input = make_input(7, true);

        assert_eq!(memo.get(&input), 7);
        // Cache hit: the check re-runs the query, which agrees with the cache.
        assert_eq!(memo.get(&input), 7);
    }

    #[test]
    #[should_panic]
    fn always_check_detects_stale_value() {
        let memo = Memo::<u64, Input<u64>, ConsistencyCheckAlways>::new(|input| input.value);

        assert_eq!(memo.get(&make_input(1, true)), 1);
        // The query now yields 2 while the cache holds 1, so the check must panic.
        let _ = memo.get(&make_input(2, true));
    }
}