1use std::collections::HashMap;
23use crate::error::{MinCutError, Result};
24
/// Caller-facing identifier of a vertex in the forest.
pub type NodeId = u64;

/// One vertex of the forest, stored in the `LinkCutTree` arena.
///
/// All links (`parent`, `left`, `right`, `path_parent`) are indices into that
/// arena rather than pointers or ids.
#[derive(Debug, Clone)]
struct SplayNode {
    /// Caller-facing identifier of this vertex (need not equal its arena index).
    id: NodeId,
    /// Arena index of this node's parent inside its splay tree.
    parent: Option<usize>,
    /// Left splay child: nodes closer to the tree root on the same path.
    left: Option<usize>,
    /// Right splay child: nodes deeper on the same path.
    right: Option<usize>,
    /// Set on the root of a detached splay tree: the node that tree hangs from.
    path_parent: Option<usize>,
    /// Number of nodes in this splay subtree (maintained by `pull_up`).
    size: usize,
    /// Value stored on this vertex.
    value: f64,
    /// Minimum of `value` over this splay subtree (maintained by `pull_up`).
    path_aggregate: f64,
    /// Lazy flag: the child order of this subtree still has to be swapped
    /// (consumed by `push_down`).
    reversed: bool,
}

impl SplayNode {
    /// Creates an isolated node with no splay links and no path-parent.
    /// Its subtree size is 1 and its aggregate equals its own `value`.
    #[inline]
    fn new(id: NodeId, value: f64) -> Self {
        Self {
            id,
            parent: None,
            left: None,
            right: None,
            path_parent: None,
            size: 1,
            value,
            path_aggregate: value,
            reversed: false,
        }
    }

    /// Returns `true` if this node is the root of its splay tree. That is the
    /// case when it has no `parent`, or when that parent does not list it as
    /// its left or right child.
    ///
    /// `self` must be an element of `nodes` (call it as
    /// `nodes[i].is_root(&nodes)`). Children are matched by address because a
    /// node does not know its own arena index, and its `id` is not an index.
    /// A node that is not stored in `nodes` is always reported as a root.
    #[inline(always)]
    fn is_root(&self, nodes: &[SplayNode]) -> bool {
        match self.parent {
            None => true,
            Some(p) => {
                let parent = &nodes[p];
                let is_me =
                    |child: Option<usize>| matches!(child, Some(c) if std::ptr::eq(&nodes[c], self));
                !is_me(parent.left) && !is_me(parent.right)
            }
        }
    }
}
82
/// A forest of rooted trees supporting `link`, `cut`, root lookup, LCA and
/// path-minimum queries, built on auxiliary splay trees stored in an arena.
pub struct LinkCutTree {
    /// Arena of all nodes; arena indices are the internal node handles.
    nodes: Vec<SplayNode>,
    /// Maps caller-facing ids to arena indices.
    id_to_index: HashMap<NodeId, usize>,
    /// Arena index -> caller-facing id, appended by `make_tree`.
    index_to_id: Vec<NodeId>,
    /// Memoised `find_root` answers: arena index -> arena index of its root.
    root_cache: HashMap<usize, usize>,
}
95
96impl LinkCutTree {
97 #[inline]
99 pub fn new() -> Self {
100 Self {
101 nodes: Vec::new(),
102 id_to_index: HashMap::new(),
103 index_to_id: Vec::new(),
104 root_cache: HashMap::new(),
105 }
106 }
107
108 #[inline]
113 pub fn with_capacity(n: usize) -> Self {
114 Self {
115 nodes: Vec::with_capacity(n),
116 id_to_index: HashMap::with_capacity(n),
117 index_to_id: Vec::with_capacity(n),
118 root_cache: HashMap::with_capacity(n / 4), }
120 }
121
122 #[inline]
124 pub fn make_tree(&mut self, id: NodeId, value: f64) -> usize {
125 let index = self.nodes.len();
126 self.nodes.push(SplayNode::new(id, value));
127 self.id_to_index.insert(id, index);
128 self.index_to_id.push(id);
129 index
130 }
131
132 #[inline]
137 fn get_index(&self, id: NodeId) -> Result<usize> {
138 self.id_to_index
139 .get(&id)
140 .copied()
141 .ok_or_else(|| self.invalid_vertex_error(id))
142 }
143
144 #[cold]
145 #[inline(never)]
146 fn invalid_vertex_error(&self, id: NodeId) -> MinCutError {
147 MinCutError::InvalidVertex(id)
148 }
149
150 pub fn link(&mut self, u: NodeId, v: NodeId) -> Result<()> {
154 let u_idx = self.get_index(u)?;
155 let v_idx = self.get_index(v)?;
156
157 if self.connected(u, v) {
159 return Err(self.already_connected_error());
160 }
161
162 self.access(u_idx);
164
165 self.access(v_idx);
167
168 self.nodes[u_idx].left = Some(v_idx);
171 self.nodes[v_idx].parent = Some(u_idx);
172 self.pull_up(u_idx);
173
174 self.invalidate_cache(u_idx);
176 self.invalidate_cache(v_idx);
177
178 Ok(())
179 }
180
181 #[cold]
182 #[inline(never)]
183 fn already_connected_error(&self) -> MinCutError {
184 MinCutError::InternalError("Nodes are already in the same tree".to_string())
185 }
186
187 pub fn cut(&mut self, v: NodeId) -> Result<()> {
190 let v_idx = self.get_index(v)?;
191
192 self.access(v_idx);
194
195 if let Some(left_idx) = self.nodes[v_idx].left {
197 self.nodes[v_idx].left = None;
198 self.nodes[left_idx].parent = None;
199 self.pull_up(v_idx);
200
201 self.invalidate_cache(v_idx);
203 self.invalidate_cache(left_idx);
204
205 Ok(())
206 } else {
207 Err(self.already_root_error())
208 }
209 }
210
211 #[cold]
212 #[inline(never)]
213 fn already_root_error(&self) -> MinCutError {
214 MinCutError::InternalError("Node is already a root".to_string())
215 }
216
217 #[inline]
224 pub fn find_root(&mut self, v: NodeId) -> Result<NodeId> {
225 let v_idx = self.get_index(v)?;
226
227 if let Some(&cached_root) = self.root_cache.get(&v_idx) {
229 if self.verify_root_cache(v_idx, cached_root) {
231 return Ok(self.nodes[cached_root].id);
232 }
233 }
234
235 self.access(v_idx);
238
239 let mut current = v_idx;
242 while let Some(left) = self.nodes[current].left {
243 self.push_down(current);
244 current = left;
246 }
247
248 self.splay(current);
250
251 self.root_cache.insert(v_idx, current);
253
254 Ok(self.nodes[current].id)
255 }
256
257 #[inline]
259 fn verify_root_cache(&self, _node_idx: usize, cached_root: usize) -> bool {
260 cached_root < self.nodes.len()
262 }
263
264 #[inline]
266 fn invalidate_cache(&mut self, root_idx: usize) {
267 self.root_cache.retain(|_, &mut cached| cached != root_idx);
270 }
271
272 #[inline]
274 pub fn connected(&mut self, u: NodeId, v: NodeId) -> bool {
275 if let (Ok(u_idx), Ok(v_idx)) = (self.get_index(u), self.get_index(v)) {
276 if u_idx == v_idx {
277 return true;
278 }
279 self.access(u_idx);
280 self.access(v_idx);
281 self.find_ancestor_root(u_idx) == self.find_ancestor_root(v_idx)
283 } else {
284 false
285 }
286 }
287
288 #[inline]
293 pub fn path_aggregate(&mut self, v: NodeId) -> Result<f64> {
294 let v_idx = self.get_index(v)?;
295 self.access(v_idx);
296 Ok(self.nodes[v_idx].path_aggregate)
297 }
298
299 #[inline]
301 pub fn update_value(&mut self, v: NodeId, value: f64) -> Result<()> {
302 let v_idx = self.get_index(v)?;
303 self.nodes[v_idx].value = value;
304 self.pull_up(v_idx);
305 Ok(())
306 }
307
308 pub fn lca(&mut self, u: NodeId, v: NodeId) -> Result<NodeId> {
310 let u_idx = self.get_index(u)?;
311 let v_idx = self.get_index(v)?;
312
313 self.access(u_idx);
319 let lca_idx = self.access_with_lca(v_idx);
322
323 Ok(self.nodes[lca_idx].id)
324 }
325
326 #[inline]
328 pub fn len(&self) -> usize {
329 self.nodes.len()
330 }
331
332 #[inline]
334 pub fn is_empty(&self) -> bool {
335 self.nodes.is_empty()
336 }
337
338 #[inline]
345 fn access(&mut self, v: usize) {
346 self.splay(v);
348
349 if let Some(right_idx) = self.nodes[v].right {
351 self.nodes[right_idx].path_parent = Some(v);
352 self.nodes[right_idx].parent = None;
353 }
354 self.nodes[v].right = None;
355 self.pull_up(v);
356
357 let mut current = v;
359 while let Some(pp) = self.nodes[current].path_parent {
360 self.splay(pp);
361
362 if let Some(old_right) = self.nodes[pp].right {
364 self.nodes[old_right].path_parent = Some(pp);
365 self.nodes[old_right].parent = None;
366 }
367
368 self.nodes[pp].right = Some(current);
369 self.nodes[current].parent = Some(pp);
370 self.nodes[current].path_parent = None;
371 self.pull_up(pp);
372
373 current = pp;
374 }
375
376 self.splay(v);
378 }
379
380 #[inline]
382 fn access_with_lca(&mut self, v: usize) -> usize {
383 self.splay(v);
384
385 if let Some(right_idx) = self.nodes[v].right {
386 self.nodes[right_idx].path_parent = Some(v);
387 self.nodes[right_idx].parent = None;
388 }
389 self.nodes[v].right = None;
390 self.pull_up(v);
391
392 let mut lca = v;
393 let mut current = v;
394
395 while let Some(pp) = self.nodes[current].path_parent {
396 lca = pp;
397 self.splay(pp);
398
399 if let Some(old_right) = self.nodes[pp].right {
400 self.nodes[old_right].path_parent = Some(pp);
401 self.nodes[old_right].parent = None;
402 }
403
404 self.nodes[pp].right = Some(current);
405 self.nodes[current].parent = Some(pp);
406 self.nodes[current].path_parent = None;
407 self.pull_up(pp);
408
409 current = pp;
410 }
411
412 self.splay(v);
413 lca
414 }
415
416 #[inline]
423 fn splay(&mut self, x: usize) {
424 while !self.nodes[x].is_root(&self.nodes) {
425 let p = self.nodes[x].parent.unwrap();
426
427 if self.nodes[p].is_root(&self.nodes) {
428 self.push_down(p);
430 self.push_down(x);
431 self.rotate(x);
432 } else {
433 let g = self.nodes[p].parent.unwrap();
435
436 self.push_down(g);
438 self.push_down(p);
439 self.push_down(x);
440
441 let x_is_left = self.nodes[p].left == Some(x);
444 let p_is_left = self.nodes[g].left == Some(p);
445
446 if x_is_left == p_is_left {
447 self.rotate(p);
450 self.rotate(x);
451 } else {
452 self.rotate(x);
455 self.rotate(x);
456 }
457 }
458 }
459
460 self.push_down(x);
461 }
462
463 #[inline]
469 fn rotate(&mut self, x: usize) {
470 let p = self.nodes[x].parent.unwrap();
471 let g = self.nodes[p].parent;
472
473 let pp = self.nodes[p].path_parent;
475
476 let x_is_left = self.nodes[p].left == Some(x);
477
478 if x_is_left {
479 let b = self.nodes[x].right;
481 self.nodes[p].left = b;
482 if let Some(b_idx) = b {
483 self.nodes[b_idx].parent = Some(p);
484 }
485 self.nodes[x].right = Some(p);
486 } else {
487 let b = self.nodes[x].left;
489 self.nodes[p].right = b;
490 if let Some(b_idx) = b {
491 self.nodes[b_idx].parent = Some(p);
492 }
493 self.nodes[x].left = Some(p);
494 }
495
496 self.nodes[p].parent = Some(x);
497 self.nodes[x].parent = g;
498
499 if let Some(g_idx) = g {
501 if self.nodes[g_idx].left == Some(p) {
502 self.nodes[g_idx].left = Some(x);
503 } else if self.nodes[g_idx].right == Some(p) {
504 self.nodes[g_idx].right = Some(x);
505 }
506 }
507
508 self.nodes[x].path_parent = pp;
510 self.nodes[p].path_parent = None;
511
512 self.pull_up(p);
514 self.pull_up(x);
515 }
516
517 #[inline(always)]
523 fn push_down(&mut self, x: usize) {
524 if !self.nodes[x].reversed {
525 return;
526 }
527
528 let left = self.nodes[x].left;
530 let right = self.nodes[x].right;
531 self.nodes[x].left = right;
532 self.nodes[x].right = left;
533
534 if let Some(left_idx) = left {
536 self.nodes[left_idx].reversed ^= true;
537 }
538 if let Some(right_idx) = right {
539 self.nodes[right_idx].reversed ^= true;
540 }
541
542 self.nodes[x].reversed = false;
543 }
544
545 #[inline(always)]
551 fn pull_up(&mut self, x: usize) {
552 let mut size = 1;
553 let mut aggregate = self.nodes[x].value;
554
555 if let Some(left_idx) = self.nodes[x].left {
556 size += self.nodes[left_idx].size;
557 aggregate = aggregate.min(self.nodes[left_idx].path_aggregate);
558 }
559
560 if let Some(right_idx) = self.nodes[x].right {
561 size += self.nodes[right_idx].size;
562 aggregate = aggregate.min(self.nodes[right_idx].path_aggregate);
563 }
564
565 self.nodes[x].size = size;
566 self.nodes[x].path_aggregate = aggregate;
567 }
568
569 #[inline]
574 fn find_ancestor_root(&self, mut x: usize) -> usize {
575 while let Some(p) = self.nodes[x].parent {
576 x = p;
577 }
578 while let Some(pp) = self.nodes[x].path_parent {
579 x = pp;
580 }
581 x
582 }
583
584 pub fn bulk_link(&mut self, edges: &[(NodeId, NodeId)]) -> Result<()> {
589 for &(u, v) in edges {
591 self.get_index(u)?;
592 self.get_index(v)?;
593 }
594
595 for &(u, v) in edges {
597 self.link(u, v)?;
598 }
599
600 self.root_cache.clear();
602
603 Ok(())
604 }
605
606 pub fn bulk_update(&mut self, updates: &[(NodeId, f64)]) -> Result<()> {
611 for &(id, value) in updates {
612 let idx = self.get_index(id)?;
613 self.nodes[idx].value = value;
614 }
615
616 for &(id, _) in updates {
618 let idx = self.get_index(id)?;
619 self.pull_up(idx);
620 }
621
622 Ok(())
623 }
624}
625
626impl Default for LinkCutTree {
627 fn default() -> Self {
628 Self::new()
629 }
630}
631
#[cfg(test)]
mod tests {
    //! Unit tests for `LinkCutTree`.
    //!
    //! Every test creates vertices in order starting from an empty forest, so
    //! each vertex id equals its arena index. `link(u, v)` makes `v` the
    //! parent of `u`.
    use super::*;

    /// `make_tree` hands out consecutive arena indices and stores each value.
    #[test]
    fn test_make_tree() {
        let mut lct = LinkCutTree::new();
        let idx0 = lct.make_tree(0, 1.0);
        let idx1 = lct.make_tree(1, 2.0);

        assert_eq!(idx0, 0);
        assert_eq!(idx1, 1);
        assert_eq!(lct.len(), 2);
        assert_eq!(lct.nodes[idx0].value, 1.0);
        assert_eq!(lct.nodes[idx1].value, 2.0);
    }

    /// The chain 0 -> 1 -> 2 is rooted at 2 for every member.
    #[test]
    fn test_link_and_find_root() {
        let mut lct = LinkCutTree::new();
        lct.make_tree(0, 1.0);
        lct.make_tree(1, 2.0);
        lct.make_tree(2, 3.0);

        lct.link(0, 1).unwrap();
        lct.link(1, 2).unwrap();

        assert_eq!(lct.find_root(0).unwrap(), 2);
        assert_eq!(lct.find_root(1).unwrap(), 2);
        assert_eq!(lct.find_root(2).unwrap(), 2);
    }

    /// Nodes in the chain 0 -> 1 -> 2 are connected; the isolated node 3 is not.
    #[test]
    fn test_connected() {
        let mut lct = LinkCutTree::new();
        lct.make_tree(0, 1.0);
        lct.make_tree(1, 2.0);
        lct.make_tree(2, 3.0);
        lct.make_tree(3, 4.0);

        lct.link(0, 1).unwrap();
        lct.link(1, 2).unwrap();

        assert!(lct.connected(0, 1));
        assert!(lct.connected(0, 2));
        assert!(lct.connected(1, 2));
        assert!(!lct.connected(0, 3));
        assert!(!lct.connected(2, 3));
    }

    /// Cutting 1 from its parent 2 splits the chain into {0, 1} (root 1) and {2}.
    #[test]
    fn test_cut() {
        let mut lct = LinkCutTree::new();
        lct.make_tree(0, 1.0);
        lct.make_tree(1, 2.0);
        lct.make_tree(2, 3.0);

        lct.link(0, 1).unwrap();
        lct.link(1, 2).unwrap();

        assert!(lct.connected(0, 2));

        lct.cut(1).unwrap();

        assert!(!lct.connected(0, 2));
        assert!(lct.connected(0, 1));
        assert_eq!(lct.find_root(0).unwrap(), 1);
        assert_eq!(lct.find_root(2).unwrap(), 2);
    }

    /// `path_aggregate` is the minimum value from a node up to its root.
    /// For the chain 0 -> 1 -> 2 -> 3 with values 5, 3, 7, 2, every path ends
    /// at node 3, whose value 2 is the minimum.
    #[test]
    fn test_path_aggregate() {
        let mut lct = LinkCutTree::new();
        lct.make_tree(0, 5.0);
        lct.make_tree(1, 3.0);
        lct.make_tree(2, 7.0);
        lct.make_tree(3, 2.0);

        lct.link(0, 1).unwrap();
        lct.link(1, 2).unwrap();
        lct.link(2, 3).unwrap();

        let agg = lct.path_aggregate(0).unwrap();
        assert_eq!(agg, 2.0);

        let agg = lct.path_aggregate(1).unwrap();
        assert_eq!(agg, 2.0);

        // The path from the root itself contains only the root.
        let agg = lct.path_aggregate(3).unwrap();
        assert_eq!(agg, 2.0);
    }

    /// Lowering a value below the rest of its path becomes the new path minimum.
    #[test]
    fn test_update_value() {
        let mut lct = LinkCutTree::new();
        lct.make_tree(0, 5.0);
        lct.make_tree(1, 3.0);

        lct.link(0, 1).unwrap();

        lct.update_value(0, 1.0).unwrap();
        let agg = lct.path_aggregate(0).unwrap();
        assert_eq!(agg, 1.0);
    }

    /// Branching tree: 4 is the root with children 3 and 2, and 0 -> 1 -> 3.
    #[test]
    fn test_lca() {
        let mut lct = LinkCutTree::new();
        for i in 0..5 {
            lct.make_tree(i, i as f64);
        }

        lct.link(0, 1).unwrap();
        lct.link(1, 3).unwrap();
        lct.link(2, 4).unwrap();
        lct.link(3, 4).unwrap();

        assert!(lct.connected(0, 1), "0 and 1 should be connected");
        assert!(lct.connected(0, 3), "0 and 3 should be connected");
        assert!(lct.connected(0, 4), "0 and 4 should be connected");
        assert!(lct.connected(2, 4), "2 and 4 should be connected");

        // 0 and 2 are on different branches that meet at the root.
        let lca = lct.lca(0, 2).unwrap();
        assert_eq!(lca, 4);

        // When one node is an ancestor of the other, it is the LCA.
        let lca = lct.lca(0, 1).unwrap();
        assert_eq!(lca, 1);

        let lca = lct.lca(0, 3).unwrap();
        assert_eq!(lca, 3);
    }

    /// Mixed sequence: build two chains, cut one, then relink across trees.
    #[test]
    fn test_complex_operations() {
        let mut lct = LinkCutTree::with_capacity(10);

        for i in 0..10 {
            lct.make_tree(i, i as f64 * 0.5);
        }

        // Chain 0 -> 1 -> 2 -> 3 -> 4.
        for i in 0..4 {
            lct.link(i, i + 1).unwrap();
        }

        // Chain 5 -> 6 -> 7.
        lct.link(5, 6).unwrap();
        lct.link(6, 7).unwrap();

        assert!(lct.connected(0, 4));
        assert!(lct.connected(5, 7));
        assert!(!lct.connected(0, 5));

        // Detach 2 from 3: {0, 1, 2} and {3, 4}.
        lct.cut(2).unwrap();
        assert!(!lct.connected(0, 4));
        assert!(lct.connected(0, 2));

        // 4 is the root of {3, 4}, so it can be hung under 7.
        lct.link(4, 7).unwrap();

        assert!(lct.connected(4, 7), "4 and 7 should be connected after link");
        assert!(lct.connected(3, 7), "3 and 7 should be connected through 4");
    }

    /// Error paths: duplicate link, cutting a root, and unknown ids.
    #[test]
    fn test_error_cases() {
        let mut lct = LinkCutTree::new();
        lct.make_tree(0, 1.0);
        lct.make_tree(1, 2.0);

        lct.link(0, 1).unwrap();
        // Already in the same tree.
        assert!(lct.link(0, 1).is_err());

        // 1 is the root of {0, 1}, so it has no parent to cut from.
        assert!(lct.cut(1).is_err());

        // Id 99 was never created.
        assert!(lct.find_root(99).is_err());
        assert!(lct.link(0, 99).is_err());
    }

    /// A 1000-node chain 0 -> ... -> 999: root lookup, path minimum and a cut
    /// in the middle.
    #[test]
    fn test_large_tree() {
        let mut lct = LinkCutTree::with_capacity(1000);

        for i in 0..1000 {
            lct.make_tree(i, i as f64);
        }

        for i in 0..999 {
            lct.link(i, i + 1).unwrap();
        }

        assert_eq!(lct.find_root(0).unwrap(), 999);
        assert_eq!(lct.find_root(500).unwrap(), 999);

        // The path from 0 covers every value 0..=999; the minimum is node 0's value.
        let agg = lct.path_aggregate(0).unwrap();
        assert_eq!(agg, 0.0);

        lct.cut(500).unwrap();
        assert_eq!(lct.find_root(0).unwrap(), 500);
        assert_eq!(lct.find_root(999).unwrap(), 999);
    }

    /// Three independent chains stay disconnected until two of them are joined.
    #[test]
    fn test_multiple_forests() {
        let mut lct = LinkCutTree::new();

        for i in 0..9 {
            lct.make_tree(i, i as f64);
        }

        lct.link(0, 1).unwrap();
        lct.link(1, 2).unwrap();

        lct.link(3, 4).unwrap();
        lct.link(4, 5).unwrap();

        lct.link(6, 7).unwrap();
        lct.link(7, 8).unwrap();

        assert_eq!(lct.find_root(0).unwrap(), 2);
        assert_eq!(lct.find_root(3).unwrap(), 5);
        assert_eq!(lct.find_root(6).unwrap(), 8);

        assert!(!lct.connected(0, 3));
        assert!(!lct.connected(3, 6));
        assert!(!lct.connected(0, 6));

        // Hang the root of {0, 1, 2} under 5.
        lct.link(2, 5).unwrap();
        assert!(lct.connected(0, 5));
        assert_eq!(lct.find_root(0).unwrap(), 5);
        assert_eq!(lct.find_root(3).unwrap(), 5);
    }

    /// `bulk_link` and `bulk_update` apply every entry of their batch.
    #[test]
    fn test_bulk_operations() {
        let mut lct = LinkCutTree::with_capacity(10);

        for i in 0..10 {
            lct.make_tree(i, i as f64);
        }

        let edges = vec![(0, 1), (1, 2), (2, 3)];
        lct.bulk_link(&edges).unwrap();

        assert!(lct.connected(0, 3));

        let updates = vec![(0, 10.0), (1, 20.0), (2, 30.0)];
        lct.bulk_update(&updates).unwrap();

        assert_eq!(lct.nodes[0].value, 10.0);
        assert_eq!(lct.nodes[1].value, 20.0);
        assert_eq!(lct.nodes[2].value, 30.0);
    }

    /// Repeated `find_root` calls may be served from `root_cache`. After a
    /// `cut`, the cache must not return the pre-cut root.
    #[test]
    fn test_root_caching() {
        let mut lct = LinkCutTree::with_capacity(100);

        for i in 0..100 {
            lct.make_tree(i, i as f64);
        }

        for i in 0..99 {
            lct.link(i, i + 1).unwrap();
        }

        let root1 = lct.find_root(0).unwrap();
        assert_eq!(root1, 99);

        // Second lookup for the same node can hit the cache entry stored above.
        let root2 = lct.find_root(0).unwrap();
        assert_eq!(root2, 99);

        // 0 now lives in the tree rooted at 50, not 99.
        lct.cut(50).unwrap();
        let root3 = lct.find_root(0).unwrap();
        assert_eq!(root3, 50);
    }
}