1use std::ops::Deref;
2use std::ops::DerefMut;
3
4use crate::bindings as ll_bindings;
5use crate::error::TskitError;
6use crate::sys;
7use crate::NodeId;
8use crate::SimplificationOptions;
9use crate::SizeType;
10use crate::TableOutputOptions;
11use crate::TreeFlags;
12use crate::TreeInterface;
13use crate::TreeSequenceFlags;
14use crate::TskReturnValue;
15use crate::{tsk_id_t, TableCollection};
16use ll_bindings::tsk_tree_free;
17use std::ptr::NonNull;
18
/// A single tree of a [`TreeSequence`].
///
/// The same value doubles as the streaming iterator over the trees of the
/// tree sequence it borrows from, and it dereferences to [`TreeInterface`],
/// so the methods of that type can be called directly on a `Tree`.
pub struct Tree<'treeseq> {
    // Heap-allocated low-level tree. Its internals are released with
    // `tsk_tree_free` when the `Tree` is dropped.
    pub(crate) inner: mbox::MBox<ll_bindings::tsk_tree_t>,
    // Never read after construction: holding the borrow ties the lifetime of
    // this tree to the tree sequence it was created from.
    #[allow(dead_code)]
    treeseq: &'treeseq TreeSequence,
    // Target of `Deref`/`DerefMut`; built from the same raw pointer that
    // `inner` owns.
    api: TreeInterface,
    // Position counter: incremented by `advance`, decremented by
    // `advance_back`. While it is zero, the next step uses
    // `tsk_tree_first`/`tsk_tree_last` rather than `tsk_tree_next`/`tsk_tree_prev`.
    current_tree: i32,
    // `true` when the most recent advance reported a tree (return value 1);
    // controls whether `get` yields `Some`.
    advanced: bool,
}
34
35impl<'treeseq> Drop for Tree<'treeseq> {
36 fn drop(&mut self) {
37 let rv = unsafe { tsk_tree_free(self.inner.as_mut()) };
39 assert_eq!(rv, 0);
40 }
41}
42
43impl<'treeseq> Deref for Tree<'treeseq> {
44 type Target = TreeInterface;
45 fn deref(&self) -> &Self::Target {
46 &self.api
47 }
48}
49
50impl<'treeseq> DerefMut for Tree<'treeseq> {
51 fn deref_mut(&mut self) -> &mut Self::Target {
52 &mut self.api
53 }
54}
55
56impl<'treeseq> Tree<'treeseq> {
57 fn new<F: Into<TreeFlags>>(ts: &'treeseq TreeSequence, flags: F) -> Result<Self, TskitError> {
58 let flags = flags.into();
59
60 let temp = unsafe {
62 libc::malloc(std::mem::size_of::<ll_bindings::tsk_tree_t>())
63 as *mut ll_bindings::tsk_tree_t
64 };
65
66 let nonnull = NonNull::<ll_bindings::tsk_tree_t>::new(temp)
68 .ok_or_else(|| TskitError::LibraryError("failed to malloc tsk_tree_t".to_string()))?;
69
70 let mut tree = unsafe { mbox::MBox::from_non_null_raw(nonnull) };
72 let mut rv =
73 unsafe { ll_bindings::tsk_tree_init(tree.as_mut(), ts.as_ptr(), flags.bits()) };
74 if rv < 0 {
75 return Err(TskitError::ErrorCode { code: rv });
76 }
77 if !flags.contains(TreeFlags::NO_SAMPLE_COUNTS) {
79 rv = unsafe {
81 ll_bindings::tsk_tree_set_tracked_samples(
82 tree.as_mut(),
83 ts.num_samples().into(),
84 (tree.as_mut()).samples,
85 )
86 };
87 }
88
89 let num_nodes = unsafe { (*(*ts.as_ptr()).tables).nodes.num_rows };
90 let api = TreeInterface::new(nonnull, num_nodes, num_nodes + 1, flags);
91 handle_tsk_return_value!(
92 rv,
93 Tree {
94 inner: tree,
95 treeseq: ts,
96 current_tree: 0,
97 advanced: false,
98 api
99 }
100 )
101 }
102}
103
104impl<'ts> streaming_iterator::StreamingIterator for Tree<'ts> {
105 type Item = Tree<'ts>;
106 fn advance(&mut self) {
107 let rv = if self.current_tree == 0 {
108 unsafe { ll_bindings::tsk_tree_first(self.as_mut_ptr()) }
109 } else {
110 unsafe { ll_bindings::tsk_tree_next(self.as_mut_ptr()) }
111 };
112 if rv == 0 {
113 self.advanced = false;
114 self.current_tree += 1;
115 } else if rv == 1 {
116 self.advanced = true;
117 self.current_tree += 1;
118 } else if rv < 0 {
119 panic_on_tskit_error!(rv);
120 }
121 }
122
123 fn get(&self) -> Option<&Self::Item> {
124 match self.advanced {
125 true => Some(self),
126 false => None,
127 }
128 }
129}
130
131impl<'ts> streaming_iterator::DoubleEndedStreamingIterator for Tree<'ts> {
132 fn advance_back(&mut self) {
133 let rv = if self.current_tree == 0 {
134 unsafe { ll_bindings::tsk_tree_last(self.as_mut_ptr()) }
135 } else {
136 unsafe { ll_bindings::tsk_tree_prev(self.as_mut_ptr()) }
137 };
138 if rv == 0 {
139 self.advanced = false;
140 self.current_tree -= 1;
141 } else if rv == 1 {
142 self.advanced = true;
143 self.current_tree -= 1;
144 } else if rv < 0 {
145 panic_on_tskit_error!(rv);
146 }
147 }
148}
149
/// A tree sequence built from a [`TableCollection`].
///
/// Construct one with [`TreeSequence::new`], [`TreeSequence::load`], or the
/// `TryFrom<TableCollection>` implementation.
pub struct TreeSequence {
    // Low-level tree sequence wrapper; the methods of `TreeSequence` hand out
    // raw pointers and references obtained from it.
    pub(crate) inner: sys::LLTreeSeq,
    // Table views created from `inner`'s raw pointer whenever a
    // `TreeSequence` is constructed (see `new` and `simplify`).
    views: crate::table_views::TableViews,
}
197
// NOTE(review): these impls assert that a `TreeSequence` may be moved to and
// shared between threads. Nothing visible in this file proves that claim; it
// presumably relies on the low-level tree sequence not being mutated through
// `&self` — confirm against `sys::LLTreeSeq` and `TableViews`.
unsafe impl Send for TreeSequence {}
unsafe impl Sync for TreeSequence {}
200
201impl TreeSequence {
202 pub fn new<F: Into<TreeSequenceFlags>>(
246 tables: TableCollection,
247 flags: F,
248 ) -> Result<Self, TskitError> {
249 let raw_tables_ptr = tables.into_raw()?;
250 let mut inner = sys::LLTreeSeq::new(raw_tables_ptr, flags.into().bits())?;
251 let views = crate::table_views::TableViews::new_from_tree_sequence(inner.as_mut_ptr())?;
252 Ok(Self { inner, views })
253 }
254
255 fn as_ref(&self) -> &ll_bindings::tsk_treeseq_t {
256 self.inner.as_ref()
257 }
258
259 pub fn as_ptr(&self) -> *const ll_bindings::tsk_treeseq_t {
261 self.inner.as_ptr()
262 }
263
264 pub fn as_mut_ptr(&mut self) -> *mut ll_bindings::tsk_treeseq_t {
266 self.inner.as_mut_ptr()
267 }
268
269 pub fn dump<O: Into<TableOutputOptions>>(&self, filename: &str, options: O) -> TskReturnValue {
282 let c_str = std::ffi::CString::new(filename).map_err(|_| {
283 TskitError::LibraryError("call to ffi::Cstring::new failed".to_string())
284 })?;
285 self.inner
286 .dump(c_str, options.into().bits())
287 .map_err(|e| e.into())
288 }
289
290 pub fn load(filename: impl AsRef<str>) -> Result<Self, TskitError> {
295 let tables = TableCollection::new_from_file(filename.as_ref())?;
296
297 Self::new(tables, TreeSequenceFlags::default())
298 }
299
300 pub fn dump_tables(&self) -> Result<TableCollection, TskitError> {
307 let mut inner = crate::table_collection::uninit_table_collection();
308
309 let rv = unsafe {
310 ll_bindings::tsk_table_collection_copy((*self.as_ptr()).tables, &mut *inner, 0)
311 };
312
313 handle_tsk_return_value!(rv, unsafe { TableCollection::new_from_mbox(inner)? })
316 }
317
318 pub fn tree_iterator<F: Into<TreeFlags>>(&self, flags: F) -> Result<Tree, TskitError> {
374 let tree = Tree::new(self, flags)?;
375
376 Ok(tree)
377 }
378
379 #[deprecated(
384 since = "0.2.3",
385 note = "Please use TreeSequence::sample_nodes instead"
386 )]
387 pub fn samples_to_vec(&self) -> Vec<NodeId> {
388 let num_samples = unsafe { ll_bindings::tsk_treeseq_get_num_samples(self.as_ptr()) };
389 let mut rv = vec![];
390
391 for i in 0..num_samples {
392 let u = match isize::try_from(i) {
393 Ok(o) => NodeId::from(unsafe { *(*self.as_ptr()).samples.offset(o) }),
394 Err(e) => panic!("{}", e),
395 };
396 rv.push(u);
397 }
398 rv
399 }
400
401 pub fn sample_nodes(&self) -> &[NodeId] {
403 let num_samples = unsafe { ll_bindings::tsk_treeseq_get_num_samples(self.as_ptr()) };
404 sys::generate_slice(self.as_ref().samples, num_samples)
405 }
406
407 pub fn num_trees(&self) -> SizeType {
409 self.inner.num_trees().into()
410 }
411
412 pub fn kc_distance(&self, other: &TreeSequence, lambda: f64) -> Result<f64, TskitError> {
424 self.inner
425 .kc_distance(&other.inner, lambda)
426 .map_err(|e| e.into())
427 }
428
429 pub fn num_samples(&self) -> SizeType {
431 self.inner.num_samples().into()
432 }
433
434 pub fn simplify<O: Into<SimplificationOptions>>(
448 &self,
449 samples: &[NodeId],
450 options: O,
451 idmap: bool,
452 ) -> Result<(Self, Option<Vec<NodeId>>), TskitError> {
453 let mut output_node_map: Vec<NodeId> = vec![];
454 if idmap {
455 output_node_map.resize(usize::try_from(self.nodes().num_rows())?, NodeId::NULL);
456 }
457 let llsamples = unsafe {
458 std::slice::from_raw_parts(samples.as_ptr().cast::<tsk_id_t>(), samples.len())
459 };
460 let mut inner = self.inner.simplify(
461 llsamples,
462 options.into().bits(),
463 match idmap {
464 true => output_node_map.as_mut_ptr().cast::<tsk_id_t>(),
465 false => std::ptr::null_mut(),
466 },
467 )?;
468 let views = crate::table_views::TableViews::new_from_tree_sequence(inner.as_mut_ptr())?;
469 Ok((
470 Self { inner, views },
471 match idmap {
472 true => Some(output_node_map),
473 false => None,
474 },
475 ))
476 }
477
478 #[cfg(feature = "provenance")]
479 #[cfg_attr(doc_cfg, doc(cfg(feature = "provenance")))]
480 pub fn add_provenance(&mut self, record: &str) -> Result<crate::ProvenanceId, TskitError> {
515 if record.is_empty() {
516 return Err(TskitError::ValueError {
517 got: "empty string".to_string(),
518 expected: "provenance record".to_string(),
519 });
520 }
521 let timestamp = humantime::format_rfc3339(std::time::SystemTime::now()).to_string();
522 let rv = unsafe {
523 ll_bindings::tsk_provenance_table_add_row(
524 &mut (*self.inner.as_ref().tables).provenances,
525 timestamp.as_ptr() as *mut i8,
526 timestamp.len() as ll_bindings::tsk_size_t,
527 record.as_ptr() as *mut i8,
528 record.len() as ll_bindings::tsk_size_t,
529 )
530 };
531 handle_tsk_return_value!(rv, crate::ProvenanceId::from(rv))
532 }
533
534 delegate_table_view_api!();
535
536 pub fn edge_differences_iter(
543 &self,
544 ) -> Result<crate::edge_differences::EdgeDifferencesIterator, TskitError> {
545 crate::edge_differences::EdgeDifferencesIterator::new_from_treeseq(self, 0)
546 }
547}
548
549impl TryFrom<TableCollection> for TreeSequence {
550 type Error = TskitError;
551
552 fn try_from(value: TableCollection) -> Result<Self, Self::Error> {
553 Self::new(value, TreeSequenceFlags::default())
554 }
555}
556
#[cfg(test)]
pub(crate) mod test_trees {
    use super::*;
    use crate::test_fixtures::{
        make_small_table_collection, make_small_table_collection_two_trees,
        treeseq_from_small_table_collection, treeseq_from_small_table_collection_two_trees,
    };
    use crate::NodeTraversalOrder;
    use streaming_iterator::DoubleEndedStreamingIterator;
    use streaming_iterator::StreamingIterator;

    // Building a tree sequence directly from the small fixture yields the
    // two expected sample nodes (ids 1 and 2).
    #[test]
    fn test_create_treeseq_new_from_tables() {
        let tables = make_small_table_collection();
        let treeseq = TreeSequence::new(tables, TreeSequenceFlags::default()).unwrap();
        let samples = treeseq.sample_nodes();
        assert_eq!(samples.len(), 2);
        for i in 1..3 {
            assert_eq!(samples[i - 1], NodeId::from(i as tsk_id_t));
        }
    }

    // The convenience constructor on `TableCollection` succeeds.
    #[test]
    fn test_create_treeseq_from_tables() {
        let tables = make_small_table_collection();
        let _treeseq = tables.tree_sequence(TreeSequenceFlags::default()).unwrap();
    }

    // Iterate the single tree of the small fixture and check samples,
    // parents, children and roots.
    #[test]
    fn test_iterate_tree_seq_with_one_tree() {
        let tables = make_small_table_collection();
        let treeseq = tables.tree_sequence(TreeSequenceFlags::default()).unwrap();
        let mut ntrees = 0;
        let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap();
        while let Some(tree) = tree_iter.next() {
            ntrees += 1;
            assert_eq!(tree.current_tree, ntrees);
            let samples = tree.sample_nodes();
            assert_eq!(samples.len(), 2);
            for i in 1..3 {
                assert_eq!(samples[i - 1], NodeId::from(i as tsk_id_t));

                let mut nsteps = 0;
                for _ in tree.parents(samples[i - 1]) {
                    nsteps += 1;
                }
                assert_eq!(nsteps, 2);
            }

            // Node ids that do not exist yield empty parent iterators.
            for i in 100..110 {
                let mut nsteps = 0;
                for _ in tree.parents(i) {
                    nsteps += 1;
                }
                assert_eq!(nsteps, 0);
            }

            assert_eq!(tree.parents(-1_i32).count(), 0);
            assert_eq!(tree.children(-1_i32).count(), 0);

            let roots = tree.roots_to_vec();
            for r in roots.iter() {
                let mut num_children = 0;
                for _ in tree.children(*r) {
                    num_children += 1;
                }
                assert_eq!(num_children, 2);
            }
        }
        assert_eq!(ntrees, 1);
    }

    // An empty (but indexed) table collection produces trees without roots.
    #[test]
    fn test_iterate_no_roots() {
        let mut tables = TableCollection::new(100.).unwrap();
        tables.build_index().unwrap();
        let treeseq = tables.tree_sequence(TreeSequenceFlags::default()).unwrap();
        let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap();
        while let Some(tree) = tree_iter.next() {
            let mut num_roots = 0;
            for _ in tree.roots() {
                num_roots += 1;
            }
            assert_eq!(num_roots, 0);
        }
    }

    // Without `SAMPLE_LISTS`, asking for a node's samples must be an error.
    #[test]
    fn test_samples_iterator_error_when_not_tracking_samples() {
        let tables = make_small_table_collection();
        let treeseq = tables.tree_sequence(TreeSequenceFlags::default()).unwrap();

        let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap();
        if let Some(tree) = tree_iter.next() {
            for n in tree.traverse_nodes(NodeTraversalOrder::Preorder) {
                match tree.samples(n) {
                    Err(_) => (),
                    _ => panic!("should not be Ok(_) or None"),
                }
            }
        }
    }

    // With default flags, tracked-sample counts are available per node.
    #[test]
    fn test_num_tracked_samples() {
        let treeseq = treeseq_from_small_table_collection();
        assert_eq!(treeseq.num_samples(), 2);
        let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap();
        if let Some(tree) = tree_iter.next() {
            assert_eq!(tree.num_tracked_samples(2).unwrap(), 1);
            assert_eq!(tree.num_tracked_samples(1).unwrap(), 1);
            assert_eq!(tree.num_tracked_samples(0).unwrap(), 2);
        }
    }

    // With `NO_SAMPLE_COUNTS` the test is expected to panic.
    #[should_panic]
    #[test]
    fn test_num_tracked_samples_not_tracking_sample_counts() {
        let treeseq = treeseq_from_small_table_collection();
        assert_eq!(treeseq.num_samples(), 2);
        let mut tree_iter = treeseq.tree_iterator(TreeFlags::NO_SAMPLE_COUNTS).unwrap();
        if let Some(tree) = tree_iter.next() {
            assert_eq!(tree.num_tracked_samples(2).unwrap(), 0);
            assert_eq!(tree.num_tracked_samples(1).unwrap(), 0);
            assert_eq!(tree.num_tracked_samples(0).unwrap(), 0);
        }
    }

    // With `SAMPLE_LISTS`, the samples below each node can be iterated and
    // agree with the tracked-sample counts.
    #[test]
    fn test_iterate_samples() {
        let tables = make_small_table_collection();
        let treeseq = tables.tree_sequence(TreeSequenceFlags::default()).unwrap();

        let mut tree_iter = treeseq.tree_iterator(TreeFlags::SAMPLE_LISTS).unwrap();
        if let Some(tree) = tree_iter.next() {
            assert!(!tree.flags().contains(TreeFlags::NO_SAMPLE_COUNTS));
            assert!(tree.flags().contains(TreeFlags::SAMPLE_LISTS));
            let mut s = vec![];

            if let Ok(iter) = tree.samples(0) {
                for i in iter {
                    s.push(i);
                }
            }
            assert_eq!(s.len(), 2);
            assert_eq!(
                s.len(),
                usize::try_from(tree.num_tracked_samples(0).unwrap()).unwrap()
            );
            assert_eq!(s[0], 1);
            assert_eq!(s[1], 2);

            for u in 1..3 {
                let mut s = vec![];
                if let Ok(iter) = tree.samples(u) {
                    for i in iter {
                        s.push(i);
                    }
                }
                assert_eq!(s.len(), 1);
                assert_eq!(s[0], u);
                assert_eq!(
                    s.len(),
                    usize::try_from(tree.num_tracked_samples(u).unwrap()).unwrap()
                );
            }
        } else {
            panic!("Expected a tree");
        }
    }

    // Two-tree fixture: check roots, per-node sample iteration in both
    // traversal orders, and agreement with the low-level preorder traversal.
    #[test]
    fn test_iterate_samples_two_trees() {
        use super::ll_bindings::tsk_size_t;
        let treeseq = treeseq_from_small_table_collection_two_trees();
        assert_eq!(treeseq.num_trees(), 2);
        let mut tree_iter = treeseq.tree_iterator(TreeFlags::SAMPLE_LISTS).unwrap();
        let expected_number_of_roots = vec![2, 1];
        // Popped from the back, so the entry for the first tree comes last.
        let mut expected_root_ids = vec![
            vec![NodeId::from(0)],
            vec![NodeId::from(1), NodeId::from(0)],
        ];
        while let Some(tree) = tree_iter.next() {
            let mut num_roots = 0;
            let eroot_ids = expected_root_ids.pop().unwrap();
            for (i, r) in tree.roots().enumerate() {
                num_roots += 1;
                assert_eq!(r, eroot_ids[i]);
            }
            assert_eq!(
                expected_number_of_roots[(tree.current_tree - 1) as usize],
                num_roots
            );
            assert_eq!(tree.roots().count(), eroot_ids.len());
            let mut preorder_nodes = vec![];
            let mut postorder_nodes = vec![];
            for n in tree.traverse_nodes(NodeTraversalOrder::Preorder) {
                let mut nsamples = 0;
                preorder_nodes.push(n);
                if let Ok(iter) = tree.samples(n) {
                    for _ in iter {
                        nsamples += 1;
                    }
                }
                assert!(nsamples > 0);
                assert_eq!(nsamples, tree.num_tracked_samples(n).unwrap());
            }
            for n in tree.traverse_nodes(NodeTraversalOrder::Postorder) {
                let mut nsamples = 0;
                postorder_nodes.push(n);
                if let Ok(iter) = tree.samples(n) {
                    for _ in iter {
                        nsamples += 1;
                    }
                }
                assert!(nsamples > 0);
                assert_eq!(nsamples, tree.num_tracked_samples(n).unwrap());
            }
            assert_eq!(preorder_nodes.len(), postorder_nodes.len());

            // Compare against the preorder traversal produced by the C API.
            {
                let mut nodes: Vec<NodeId> = vec![
                    NodeId::NULL;
                    unsafe { ll_bindings::tsk_tree_get_size_bound(tree.as_ptr()) }
                        as usize
                ];
                let mut num_nodes: tsk_size_t = 0;
                let ptr = std::ptr::addr_of_mut!(num_nodes);
                unsafe {
                    ll_bindings::tsk_tree_preorder(
                        tree.as_ptr(),
                        nodes.as_mut_ptr() as *mut tsk_id_t,
                        ptr,
                    );
                }
                assert_eq!(num_nodes as usize, preorder_nodes.len());
                for i in 0..num_nodes as usize {
                    assert_eq!(preorder_nodes[i], nodes[i]);
                }
            }
        }
    }

    // The KC distance between two identical tree sequences is zero.
    #[test]
    fn test_kc_distance_naive_test() {
        let ts1 = treeseq_from_small_table_collection();
        let ts2 = treeseq_from_small_table_collection();

        let kc = ts1.kc_distance(&ts2, 0.0).unwrap();
        assert!(kc.is_finite());
        assert!((kc - 0.).abs() < f64::EPSILON);
    }

    // Tables dumped from a tree sequence equal a deep copy taken beforehand.
    #[test]
    fn test_dump_tables() {
        let tables = make_small_table_collection_two_trees();
        let tables_copy = tables.deepcopy().unwrap();
        let ts = tables.tree_sequence(TreeSequenceFlags::default()).unwrap();
        let dumped = ts.dump_tables().unwrap();
        assert!(tables_copy.equals(&dumped, crate::TableEqualityOptions::default()));
    }

    // Forward iteration followed by backward iteration visits the same
    // intervals in opposite order.
    #[test]
    fn test_reverse_tree_iteration() {
        let treeseq = treeseq_from_small_table_collection_two_trees();
        let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap();
        let mut starts_fwd = vec![];
        let mut stops_fwd = vec![];
        let mut starts_rev = vec![];
        let mut stops_rev = vec![];
        while let Some(tree) = tree_iter.next() {
            let interval = tree.interval();
            starts_fwd.push(interval.0);
            stops_fwd.push(interval.1);
        }
        // Both forward vectors are checked; previously `stops_fwd` was
        // asserted twice and `starts_fwd` not at all.
        assert_eq!(starts_fwd.len(), 2);
        assert_eq!(stops_fwd.len(), 2);

        while let Some(tree) = tree_iter.next_back() {
            let interval = tree.interval();
            starts_rev.push(interval.0);
            stops_rev.push(interval.1);
        }
        assert_eq!(starts_fwd.len(), starts_rev.len());
        assert_eq!(stops_fwd.len(), stops_rev.len());

        starts_rev.reverse();
        assert!(starts_fwd == starts_rev);
        stops_rev.reverse();
        assert!(stops_fwd == stops_rev);
    }

    // The parent array borrowed from a tree stays usable while values copied
    // out of it are compared against it.
    #[test]
    fn test_array_lifetime() {
        let treeseq = treeseq_from_small_table_collection_two_trees();
        let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap();
        if let Some(tree) = tree_iter.next() {
            let pa = tree.parent_array();
            let mut pc = vec![];
            for i in pa.iter() {
                pc.push(*i);
            }
            for (i, p) in pc.iter().enumerate() {
                assert_eq!(pa[i], *p);
            }
        } else {
            panic!("Expected a tree.");
        }
    }
}
872
#[cfg(test)]
mod test_treeeseq_send_sync {
    use crate::test_fixtures::treeseq_from_small_table_collection_two_trees;
    use std::sync::Arc;
    use std::thread;

    // A tree sequence wrapped in an `Arc` can be moved into another thread
    // and queried there.
    #[test]
    fn build_arc() {
        let shared = Arc::new(treeseq_from_small_table_collection_two_trees());
        let worker = thread::spawn(move || shared.num_trees());
        let ntrees = worker.join().unwrap();
        assert_eq!(ntrees, 2);
    }
}