1use std::cmp::max;
2
3use crate::RefCounter;
4
/// A persistent (immutable) AVL tree mapping keys of type `K` to values of type `V`.
///
/// Updates never modify an existing tree: they build new nodes along the affected
/// path and share every untouched subtree, key and value with the previous version
/// through `RefCounter` handles. Leaving `V` at its default `()` makes the tree a
/// set of `K`.
pub enum AVL<K, V = ()> {
    /// A tree with no entries.
    Empty,
    /// One entry plus its two subtrees. Keys in `left` order before `key`, keys in
    /// `right` order after it.
    Node {
        /// Key of this entry.
        key: RefCounter<K>,
        /// Value associated with `key`.
        value: RefCounter<V>,
        /// Subtree holding the smaller keys.
        left: RefCounter<Self>,
        /// Subtree holding the larger keys.
        right: RefCounter<Self>,
    },
}
14
15pub type OrderedMap<K, V> = AVL<K, V>;
16pub type OrderedSet<K> = AVL<K>;
17
18impl<K, V> Clone for AVL<K, V> {
19 fn clone(&self) -> Self {
20 match self {
21 Self::Empty => Self::Empty,
22 Self::Node {
23 key,
24 value,
25 left,
26 right,
27 } => Self::Node {
28 key: key.clone(),
29 value: value.clone(),
30 left: left.clone(),
31 right: right.clone(),
32 },
33 }
34 }
35}
36
37impl<K: Ord> AVL<K> {
38 pub fn insert(&self, value: K) -> Self {
39 self.put(value, ())
40 }
41 pub fn search(&self, value: &K) -> bool {
42 self.find(value).is_some()
43 }
44}
45
46impl<K: Ord, V> AVL<K, V> {
47 pub fn empty() -> AVL<K, V> {
48 AVL::Empty
49 }
50 pub fn is_empty(&self) -> bool {
51 matches!(self, AVL::Empty)
52 }
53 fn height(&self) -> i64 {
54 match self {
55 AVL::Empty => 0,
56 AVL::Node {
57 key: _,
58 value: _,
59 left,
60 right,
61 } => 1 + max(&left.height(), &right.height()),
62 }
63 }
64 fn diff(&self) -> i64 {
65 match self {
66 AVL::Empty => 0,
67 AVL::Node {
68 key: _,
69 value: _,
70 left,
71 right,
72 } => left.height() - right.height(),
73 }
74 }
75 pub fn find(&self, target_value: &K) -> Option<&V> {
76 match self {
77 AVL::Empty => Option::None,
78 AVL::Node {
79 key,
80 value,
81 left,
82 right,
83 } => match target_value.cmp(key) {
84 std::cmp::Ordering::Less => left.find(target_value),
85 std::cmp::Ordering::Equal => Option::Some(value.as_ref()),
86 std::cmp::Ordering::Greater => right.find(target_value),
87 },
88 }
89 }
90 fn right_rotation(&self) -> AVL<K, V> {
91 if let AVL::Node {
92 key: x,
93 value: vx,
94 left: lt,
95 right: t3,
96 } = self
97 {
98 if let AVL::Node {
99 key: y,
100 value: vy,
101 left: t1,
102 right: t2,
103 } = (*lt).as_ref()
104 {
105 return AVL::Node {
106 key: y.clone(),
107 left: t1.clone(),
108 value: vy.clone(),
109 right: RefCounter::new(AVL::Node {
110 key: x.clone(),
111 value: vx.clone(),
112 left: t2.clone(),
113 right: t3.clone(),
114 }),
115 };
116 }
117 }
118 self.clone()
119 }
120 fn right_fix(&self) -> AVL<K, V> {
121 if let AVL::Node {
122 key: x,
123 value: vx,
124 left: t1,
125 right: t2,
126 } = self
127 {
128 if t1.diff() == -1 {
129 return AVL::Node {
130 key: x.clone(),
131 value: vx.clone(),
132 left: RefCounter::new(t1.left_rotation()),
133 right: t2.clone(),
134 }
135 .right_rotation();
136 } else {
137 return self.right_rotation();
138 }
139 }
140 self.clone()
141 }
142 fn left_rotation(&self) -> AVL<K, V> {
143 if let AVL::Node {
144 key: x,
145 value: vx,
146 left: t1,
147 right: rt,
148 } = self
149 {
150 if let AVL::Node {
151 key: y,
152 value: vy,
153 left: t2,
154 right: t3,
155 } = (*rt).as_ref()
156 {
157 return AVL::Node {
158 key: y.clone(),
159 value: vy.clone(),
160 left: RefCounter::new(AVL::Node {
161 key: x.clone(),
162 value: vx.clone(),
163 left: t1.clone(),
164 right: t2.clone(),
165 }),
166 right: t3.clone(),
167 };
168 }
169 }
170 self.clone()
171 }
172 fn left_fix(&self) -> AVL<K, V> {
173 if let AVL::Node {
174 key: x,
175 value: vx,
176 left: t1,
177 right: t2,
178 } = self
179 {
180 if t2.diff() == 1 {
181 return AVL::Node {
182 key: x.clone(),
183 value: vx.clone(),
184 left: t1.clone(),
185 right: RefCounter::new(t2.right_rotation()),
186 }
187 .left_rotation();
188 } else {
189 return self.left_rotation();
190 }
191 }
192 self.clone()
193 }
194 fn fix(&self) -> AVL<K, V> {
195 match self.diff() {
196 2 => self.right_fix(),
197 -2 => self.left_fix(),
198 _ => self.clone(),
199 }
200 }
201 pub fn put(&self, key: K, value: V) -> AVL<K, V> {
202 self.put_rc(RefCounter::new(key), RefCounter::new(value))
203 }
204 fn put_rc(&self, key_rc: RefCounter<K>, value_rc: RefCounter<V>) -> AVL<K, V> {
205 match self {
206 AVL::Empty => AVL::Node {
207 key: key_rc,
208 value: value_rc,
209 left: RefCounter::new(AVL::Empty),
210 right: RefCounter::new(AVL::Empty),
211 },
212 AVL::Node {
213 key,
214 value,
215 left,
216 right,
217 } => match key_rc.cmp(key) {
218 std::cmp::Ordering::Less => AVL::Node {
219 key: key.clone(),
220 value: value.clone(),
221 left: RefCounter::new(left.put_rc(key_rc, value_rc)),
222 right: right.clone(),
223 }
224 .fix(),
225 std::cmp::Ordering::Equal => AVL::Node {
226 key: key_rc,
227 value: value_rc,
228 left: left.clone(),
229 right: right.clone(),
230 },
231 std::cmp::Ordering::Greater => AVL::Node {
232 key: key.clone(),
233 value: value.clone(),
234 left: left.clone(),
235 right: RefCounter::new(right.put_rc(key_rc, value_rc)),
236 }
237 .fix(),
238 },
239 }
240 }
241 pub fn delete(&self, target_key: &K) -> AVL<K, V> {
242 match self {
243 AVL::Empty => AVL::Empty,
244 AVL::Node {
245 key,
246 value,
247 left,
248 right,
249 } => {
250 match target_key.cmp(key) {
251 std::cmp::Ordering::Less => {
252 let left_deleted = left.delete(target_key);
253 AVL::Node {
254 key: key.clone(),
255 value: value.clone(),
256 left: RefCounter::new(left_deleted),
257 right: right.clone(),
258 }
259 .fix()
260 }
261 std::cmp::Ordering::Equal => {
262 if left.is_empty() {
264 return right.as_ref().clone();
265 } else if right.is_empty() {
266 return left.as_ref().clone();
267 }
268
269 let inorder_predecessor = left.find_max();
271 if let Some((pred_key, pred_value)) = inorder_predecessor {
272 let left_deleted = left.delete(&pred_key);
273 AVL::Node {
274 key: pred_key.clone(),
275 value: pred_value.clone(),
276 left: RefCounter::new(left_deleted),
277 right: right.clone(),
278 }
279 .fix()
280 } else {
281 self.clone()
282 }
283 }
284 std::cmp::Ordering::Greater => {
285 let right_deleted = right.delete(target_key);
286 AVL::Node {
287 key: key.clone(),
288 value: value.clone(),
289 left: left.clone(),
290 right: RefCounter::new(right_deleted),
291 }
292 .fix()
293 }
294 }
295 }
296 }
297 }
298
299 fn find_max(&self) -> Option<(RefCounter<K>, RefCounter<V>)> {
300 match self {
301 AVL::Empty => None,
302 AVL::Node {
303 key,
304 value,
305 left: _,
306 right,
307 } => {
308 if right.is_empty() {
309 Some((key.clone(), value.clone()))
310 } else {
311 right.find_max()
312 }
313 }
314 }
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// True when every node's balance factor is -1, 0 or 1.
    fn is_balanced<K: Ord, V>(tree: &AVL<K, V>) -> bool {
        match tree {
            AVL::Empty => true,
            AVL::Node { left, right, .. } => {
                tree.diff().abs() <= 1 && is_balanced(&**left) && is_balanced(&**right)
            }
        }
    }

    /// Appends the keys of `tree` to `out` in in-order (left, node, right) sequence.
    fn in_order<K: Copy, V>(tree: &AVL<K, V>, out: &mut Vec<K>) {
        if let AVL::Node {
            key, left, right, ..
        } = tree
        {
            in_order(&**left, out);
            out.push(**key);
            in_order(&**right, out);
        }
    }

    #[test]
    fn test_avl_set() {
        let l = AVL::empty().insert(1).insert(2).insert(3).insert(4);
        let l2 = l.clone().insert(5);
        for i in 1..=4 {
            assert!(l.search(&i));
            assert!(l2.search(&i));
        }
        assert!(!l.search(&5));
        assert!(l2.search(&5));
    }

    #[test]
    fn test_avl_map() {
        let l = AVL::empty().put(1, 999);
        let l2 = l.clone().put(1, 123).put(2, 3);
        assert_eq!(l.find(&1), Some(&999));
        assert_eq!(l2.find(&1), Some(&123));
        assert!(l.find(&2).is_none());
        assert!(l2.find(&2).is_some());
    }

    #[test]
    fn test_avl_delete() {
        let l = AVL::empty()
            .insert(1)
            .insert(2)
            .insert(3)
            .insert(4)
            .insert(5);
        let l = l.delete(&3);
        assert!(!l.search(&3));
        assert!(l.search(&1));
        assert!(l.search(&2));
        assert!(l.search(&4));
        assert!(l.search(&5));
    }

    #[test]
    fn test_avl_invariants_hold_for_various_insertion_orders() {
        let orders: Vec<Vec<i32>> = vec![
            (0..100).collect(),
            (0..100).rev().collect(),
            // 37 is invertible modulo the prime 101, so these are 100 distinct keys
            // in scrambled order.
            (0..100).map(|i| (i * 37) % 101).collect(),
        ];
        for order in &orders {
            let mut tree: OrderedSet<i32> = AVL::empty();
            for &k in order {
                tree = tree.insert(k);
                assert!(is_balanced(&tree));
            }

            let mut expected = order.clone();
            expected.sort();
            let mut keys = Vec::new();
            in_order(&tree, &mut keys);
            assert_eq!(keys, expected);

            for &k in order.iter().step_by(3) {
                tree = tree.delete(&k);
                assert!(is_balanced(&tree));
                assert!(!tree.search(&k));
            }
            for (i, &k) in order.iter().enumerate() {
                assert_eq!(tree.search(&k), i % 3 != 0);
            }
        }
    }

    #[test]
    fn test_avl_delete_is_persistent() {
        let original: OrderedMap<i32, i32> =
            (1..=7).fold(AVL::empty(), |t, k| t.put(k, k * k));
        let removed = original.delete(&4);
        assert_eq!(removed.find(&4), None);
        assert_eq!(original.find(&4), Some(&16));
        for &k in &[1, 2, 3, 5, 6, 7] {
            assert_eq!(removed.find(&k), Some(&(k * k)));
        }
        assert!(is_balanced(&removed));

        // Deleting a key that is not present leaves every lookup unchanged.
        let unchanged = removed.delete(&42);
        for k in 0..=8 {
            assert_eq!(unchanged.find(&k), removed.find(&k));
        }

        // Removing every key yields the empty tree.
        let emptied = (1..=7).fold(original, |t, k| t.delete(&k));
        assert!(emptied.is_empty());
    }
}