tskit/
trees.rs

1use std::ops::Deref;
2use std::ops::DerefMut;
3
4use crate::bindings as ll_bindings;
5use crate::error::TskitError;
6use crate::sys;
7use crate::NodeId;
8use crate::SimplificationOptions;
9use crate::SizeType;
10use crate::TableOutputOptions;
11use crate::TreeFlags;
12use crate::TreeInterface;
13use crate::TreeSequenceFlags;
14use crate::TskReturnValue;
15use crate::{tsk_id_t, TableCollection};
16use ll_bindings::tsk_tree_free;
17use std::ptr::NonNull;
18
/// A Tree.
///
/// Wrapper around `tsk_tree_t`.
pub struct Tree<'treeseq> {
    // Heap-allocated low-level tree. The `Drop` impl calls
    // `tsk_tree_free` on it.
    pub(crate) inner: mbox::MBox<ll_bindings::tsk_tree_t>,
    // NOTE: this reference exists because tsk_tree_t
    // contains a NON-OWNING pointer to tsk_treeseq_t.
    // Thus, we could theoretically cause UB without
    // tying the rust-side object lifetimes together.
    #[allow(dead_code)]
    treeseq: &'treeseq TreeSequence,
    // Target of the `Deref`/`DerefMut` impls for this type.
    api: TreeInterface,
    // Step counter maintained by the streaming-iterator impls:
    // incremented by `advance`, decremented by `advance_back`.
    // A value of 0 makes the next step call `tsk_tree_first`/`tsk_tree_last`.
    current_tree: i32,
    // `true` when the most recent advance returned 1 from the C API;
    // `get` only yields the tree while this is set.
    advanced: bool,
}
34
35impl<'treeseq> Drop for Tree<'treeseq> {
36    fn drop(&mut self) {
37        // SAFETY: Mbox<_> cannot hold a NULL ptr
38        let rv = unsafe { tsk_tree_free(self.inner.as_mut()) };
39        assert_eq!(rv, 0);
40    }
41}
42
43impl<'treeseq> Deref for Tree<'treeseq> {
44    type Target = TreeInterface;
45    fn deref(&self) -> &Self::Target {
46        &self.api
47    }
48}
49
50impl<'treeseq> DerefMut for Tree<'treeseq> {
51    fn deref_mut(&mut self) -> &mut Self::Target {
52        &mut self.api
53    }
54}
55
56impl<'treeseq> Tree<'treeseq> {
57    fn new<F: Into<TreeFlags>>(ts: &'treeseq TreeSequence, flags: F) -> Result<Self, TskitError> {
58        let flags = flags.into();
59
60        // SAFETY: this is the type we want :)
61        let temp = unsafe {
62            libc::malloc(std::mem::size_of::<ll_bindings::tsk_tree_t>())
63                as *mut ll_bindings::tsk_tree_t
64        };
65
66        // Get our pointer into MBox ASAP
67        let nonnull = NonNull::<ll_bindings::tsk_tree_t>::new(temp)
68            .ok_or_else(|| TskitError::LibraryError("failed to malloc tsk_tree_t".to_string()))?;
69
70        // SAFETY: if temp is NULL, we have returned Err already.
71        let mut tree = unsafe { mbox::MBox::from_non_null_raw(nonnull) };
72        let mut rv =
73            unsafe { ll_bindings::tsk_tree_init(tree.as_mut(), ts.as_ptr(), flags.bits()) };
74        if rv < 0 {
75            return Err(TskitError::ErrorCode { code: rv });
76        }
77        // Gotta ask Jerome about this one--why isn't this handled in tsk_tree_init??
78        if !flags.contains(TreeFlags::NO_SAMPLE_COUNTS) {
79            // SAFETY: nobody is null here.
80            rv = unsafe {
81                ll_bindings::tsk_tree_set_tracked_samples(
82                    tree.as_mut(),
83                    ts.num_samples().into(),
84                    (tree.as_mut()).samples,
85                )
86            };
87        }
88
89        let num_nodes = unsafe { (*(*ts.as_ptr()).tables).nodes.num_rows };
90        let api = TreeInterface::new(nonnull, num_nodes, num_nodes + 1, flags);
91        handle_tsk_return_value!(
92            rv,
93            Tree {
94                inner: tree,
95                treeseq: ts,
96                current_tree: 0,
97                advanced: false,
98                api
99            }
100        )
101    }
102}
103
104impl<'ts> streaming_iterator::StreamingIterator for Tree<'ts> {
105    type Item = Tree<'ts>;
106    fn advance(&mut self) {
107        let rv = if self.current_tree == 0 {
108            unsafe { ll_bindings::tsk_tree_first(self.as_mut_ptr()) }
109        } else {
110            unsafe { ll_bindings::tsk_tree_next(self.as_mut_ptr()) }
111        };
112        if rv == 0 {
113            self.advanced = false;
114            self.current_tree += 1;
115        } else if rv == 1 {
116            self.advanced = true;
117            self.current_tree += 1;
118        } else if rv < 0 {
119            panic_on_tskit_error!(rv);
120        }
121    }
122
123    fn get(&self) -> Option<&Self::Item> {
124        match self.advanced {
125            true => Some(self),
126            false => None,
127        }
128    }
129}
130
131impl<'ts> streaming_iterator::DoubleEndedStreamingIterator for Tree<'ts> {
132    fn advance_back(&mut self) {
133        let rv = if self.current_tree == 0 {
134            unsafe { ll_bindings::tsk_tree_last(self.as_mut_ptr()) }
135        } else {
136            unsafe { ll_bindings::tsk_tree_prev(self.as_mut_ptr()) }
137        };
138        if rv == 0 {
139            self.advanced = false;
140            self.current_tree -= 1;
141        } else if rv == 1 {
142            self.advanced = true;
143            self.current_tree -= 1;
144        } else if rv < 0 {
145            panic_on_tskit_error!(rv);
146        }
147    }
148}
149
150/// A tree sequence.
151///
152/// This is a thin wrapper around the C type `tsk_treeseq_t`.
153///
154/// When created from a [`TableCollection`], the input tables are
155/// moved into the `TreeSequence` object.
156///
157/// # Examples
158///
159/// ```
160/// let mut tables = tskit::TableCollection::new(1000.).unwrap();
161/// tables.add_node(0, 1.0, tskit::PopulationId::NULL, tskit::IndividualId::NULL).unwrap();
162/// tables.add_node(0, 0.0, tskit::PopulationId::NULL, tskit::IndividualId::NULL).unwrap();
163/// tables.add_node(0, 0.0, tskit::PopulationId::NULL, tskit::IndividualId::NULL).unwrap();
164/// tables.add_edge(0., 1000., 0, 1).unwrap();
165/// tables.add_edge(0., 1000., 0, 2).unwrap();
166///
167/// // index
168/// tables.build_index();
169///
170/// // tables gets moved into our treeseq variable:
171/// let treeseq = tables.tree_sequence(tskit::TreeSequenceFlags::default()).unwrap();
172/// assert_eq!(treeseq.nodes().num_rows(), 3);
173/// assert_eq!(treeseq.edges().num_rows(), 2);
174/// ```
175///
176/// This type does not provide access to mutable tables.
177///
178/// ```compile_fail
179/// # let mut tables = tskit::TableCollection::new(1000.).unwrap();
180/// # tables.add_node(0, 1.0, tskit::PopulationId::NULL, tskit::IndividualId::NULL).unwrap();
181/// # tables.add_node(0, 0.0, tskit::PopulationId::NULL, tskit::IndividualId::NULL).unwrap();
182/// # tables.add_node(0, 0.0, tskit::PopulationId::NULL, tskit::IndividualId::NULL).unwrap();
183/// # tables.add_edge(0., 1000., 0, 1).unwrap();
184/// # tables.add_edge(0., 1000., 0, 2).unwrap();
185///
186/// # // index
187/// # tables.build_index();
188///
189/// # // tables gets moved into our treeseq variable:
190/// # let treeseq = tables.tree_sequence(tskit::TreeSequenceFlags::default()).unwrap();
191/// assert_eq!(treeseq.nodes_mut().num_rows(), 3);
192/// ```
pub struct TreeSequence {
    // Low-level tree sequence; most methods delegate to it.
    pub(crate) inner: sys::LLTreeSeq,
    // Table views built from `inner`'s pointer when the value is constructed
    // (see `new` and `simplify`).
    views: crate::table_views::TableViews,
}
197
// NOTE(review): these impls assert that a `TreeSequence` may be moved to and
// shared between threads. Nothing in this file enforces that for the wrapped
// C data; the claim presumably rests on the tree sequence not being mutated
// through `&self` — confirm, keeping in mind that `add_provenance`
// (feature "provenance") writes to the tables through `&mut self`.
unsafe impl Send for TreeSequence {}
unsafe impl Sync for TreeSequence {}
200
201impl TreeSequence {
202    /// Create a tree sequence from a [`TableCollection`].
203    /// In general, [`TableCollection::tree_sequence`] may be preferred.
204    /// The table collection is moved/consumed.
205    ///
206    /// # Parameters
207    ///
208    /// * `tables`, a [`TableCollection`]
209    ///
210    /// # Errors
211    ///
212    /// * [`TskitError`] if the tables are not indexed.
213    /// * [`TskitError`] if the tables are not properly sorted.
214    ///   See [`TableCollection::full_sort`](crate::TableCollection::full_sort).
215    ///
216    /// # Examples
217    ///
218    /// ```
219    /// let mut tables = tskit::TableCollection::new(1000.).unwrap();
220    /// tables.build_index();
221    /// let tree_sequence = tskit::TreeSequence::try_from(tables).unwrap();
222    /// ```
223    ///
224    /// The following may be preferred to the previous example, and more closely
225    /// mimics the Python `tskit` interface:
226    ///
227    /// ```
228    /// let mut tables = tskit::TableCollection::new(1000.).unwrap();
229    /// tables.build_index();
230    /// let tree_sequence = tables.tree_sequence(tskit::TreeSequenceFlags::default()).unwrap();
231    /// ```
232    ///
233    /// The following raises an error because the tables are not indexed:
234    ///
235    /// ```should_panic
236    /// let mut tables = tskit::TableCollection::new(1000.).unwrap();
237    /// let tree_sequence = tskit::TreeSequence::try_from(tables).unwrap();
238    /// ```
239    ///
240    /// ## Note
241    ///
242    /// This function makes *no extra copies* of the tables.
243    /// There is, however, a temporary allocation of an empty table collection
244    /// in order to convince rust that we are safely handling all memory.
245    pub fn new<F: Into<TreeSequenceFlags>>(
246        tables: TableCollection,
247        flags: F,
248    ) -> Result<Self, TskitError> {
249        let raw_tables_ptr = tables.into_raw()?;
250        let mut inner = sys::LLTreeSeq::new(raw_tables_ptr, flags.into().bits())?;
251        let views = crate::table_views::TableViews::new_from_tree_sequence(inner.as_mut_ptr())?;
252        Ok(Self { inner, views })
253    }
254
    /// Borrow the low-level `tsk_treeseq_t` held by `inner`.
    fn as_ref(&self) -> &ll_bindings::tsk_treeseq_t {
        self.inner.as_ref()
    }
258
    /// Pointer to the low-level C type.
    ///
    /// The pointer is obtained from the wrapped low-level tree sequence.
    pub fn as_ptr(&self) -> *const ll_bindings::tsk_treeseq_t {
        self.inner.as_ptr()
    }
263
    /// Mutable pointer to the low-level C type.
    ///
    /// The pointer is obtained from the wrapped low-level tree sequence.
    pub fn as_mut_ptr(&mut self) -> *mut ll_bindings::tsk_treeseq_t {
        self.inner.as_mut_ptr()
    }
268
269    /// Dump the tree sequence to file.
270    ///
271    /// # Note
272    ///
273    /// * `options` is currently not used.  Set to default value.
274    ///   This behavior may change in a future release, which could
275    ///   break `API`.
276    ///
277    /// # Panics
278    ///
279    /// This function allocates a `CString` to pass the file name to the C API.
280    /// A panic will occur if the system runs out of memory.
281    pub fn dump<O: Into<TableOutputOptions>>(&self, filename: &str, options: O) -> TskReturnValue {
282        let c_str = std::ffi::CString::new(filename).map_err(|_| {
283            TskitError::LibraryError("call to ffi::Cstring::new failed".to_string())
284        })?;
285        self.inner
286            .dump(c_str, options.into().bits())
287            .map_err(|e| e.into())
288    }
289
290    /// Load from a file.
291    ///
292    /// This function calls [`TableCollection::new_from_file`] with
293    /// [`TreeSequenceFlags::default`].
294    pub fn load(filename: impl AsRef<str>) -> Result<Self, TskitError> {
295        let tables = TableCollection::new_from_file(filename.as_ref())?;
296
297        Self::new(tables, TreeSequenceFlags::default())
298    }
299
300    /// Obtain a copy of the [`TableCollection`].
301    /// The result is a "deep" copy of the tables.
302    ///
303    /// # Errors
304    ///
305    /// [`TskitError`] will be raised if the underlying C library returns an error code.
306    pub fn dump_tables(&self) -> Result<TableCollection, TskitError> {
307        let mut inner = crate::table_collection::uninit_table_collection();
308
309        let rv = unsafe {
310            ll_bindings::tsk_table_collection_copy((*self.as_ptr()).tables, &mut *inner, 0)
311        };
312
313        // SAFETY: we just initialized it.
314        // The C API doesn't free NULL pointers.
315        handle_tsk_return_value!(rv, unsafe { TableCollection::new_from_mbox(inner)? })
316    }
317
318    /// Create an iterator over trees.
319    ///
320    /// # Parameters
321    ///
322    /// * `flags` A [`TreeFlags`] bit field.
323    ///
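    /// # Errors
    ///
    /// * [`TskitError`] if the tree cannot be allocated or the C library
    ///   reports an error while initializing it.
    ///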
326    /// # Examples
327    ///
328    /// ```
329    /// // You must include streaming_iterator as a dependency
330    /// // and import this type.
331    /// use streaming_iterator::StreamingIterator;
332    /// // Import this to allow .next_back() for reverse
333    /// // iteration over trees.
334    /// use streaming_iterator::DoubleEndedStreamingIterator;
335    ///
336    /// let mut tables = tskit::TableCollection::new(1000.).unwrap();
337    /// tables.build_index();
338    /// let tree_sequence = tables.tree_sequence(tskit::TreeSequenceFlags::default()).unwrap();
339    /// let mut tree_iterator = tree_sequence.tree_iterator(tskit::TreeFlags::default()).unwrap();
340    /// while let Some(tree) = tree_iterator.next() {
341    /// }
342    /// ```
343    ///
    /// ## Coupled lifetimes
345    ///
346    /// A `Tree`'s lifetime is tied to that of its tree sequence:
347    ///
348    /// ```{compile_fail}
349    /// # use streaming_iterator::StreamingIterator;
350    /// # use streaming_iterator::DoubleEndedStreamingIterator;
351    /// # let mut tables = tskit::TableCollection::new(1000.).unwrap();
352    /// # tables.build_index();
353    /// let tree_sequence = tables.tree_sequence(tskit::TreeSequenceFlags::default()).unwrap();
354    /// let mut tree_iterator = tree_sequence.tree_iterator(tskit::TreeFlags::default()).unwrap();
355    /// drop(tree_sequence);
356    /// while let Some(tree) = tree_iterator.next() { // compile fail.
357    /// }
358    /// ```
359    /// # Warning
360    ///
361    /// The following code results in an infinite loop.
362    /// Be sure to note the difference from the previous example.
363    ///
364    /// ```no_run
365    /// use streaming_iterator::StreamingIterator;
366    ///
367    /// let mut tables = tskit::TableCollection::new(1000.).unwrap();
368    /// tables.build_index();
369    /// let tree_sequence = tables.tree_sequence(tskit::TreeSequenceFlags::default()).unwrap();
370    /// while let Some(tree) = tree_sequence.tree_iterator(tskit::TreeFlags::default()).unwrap().next() {
371    /// }
372    /// ```
373    pub fn tree_iterator<F: Into<TreeFlags>>(&self, flags: F) -> Result<Tree, TskitError> {
374        let tree = Tree::new(self, flags)?;
375
376        Ok(tree)
377    }
378
379    /// Get the list of samples as a vector.
380    /// # Panics
381    ///
382    /// Will panic if the number of samples is too large to cast to a valid id.
383    #[deprecated(
384        since = "0.2.3",
385        note = "Please use TreeSequence::sample_nodes instead"
386    )]
387    pub fn samples_to_vec(&self) -> Vec<NodeId> {
388        let num_samples = unsafe { ll_bindings::tsk_treeseq_get_num_samples(self.as_ptr()) };
389        let mut rv = vec![];
390
391        for i in 0..num_samples {
392            let u = match isize::try_from(i) {
393                Ok(o) => NodeId::from(unsafe { *(*self.as_ptr()).samples.offset(o) }),
394                Err(e) => panic!("{}", e),
395            };
396            rv.push(u);
397        }
398        rv
399    }
400
401    /// Get the list of sample nodes as a slice.
402    pub fn sample_nodes(&self) -> &[NodeId] {
403        let num_samples = unsafe { ll_bindings::tsk_treeseq_get_num_samples(self.as_ptr()) };
404        sys::generate_slice(self.as_ref().samples, num_samples)
405    }
406
    /// Get the number of trees.
    ///
    /// The value reported by the low-level tree sequence is converted
    /// to a [`SizeType`].
    pub fn num_trees(&self) -> SizeType {
        self.inner.num_trees().into()
    }
411
    /// Calculate the average Kendall-Colijn (`K-C`) distance between
    /// pairs of trees whose intervals overlap.
    ///
    /// # Note
    ///
    /// * [Citation](https://doi.org/10.1093/molbev/msw124)
    ///
    /// # Parameters
    ///
    /// * `other`: the tree sequence to compare against.
    /// * `lambda` specifies the relative weight of topology and branch length.
    ///   See [`TreeInterface::kc_distance`] for more details.
    ///
    /// # Errors
    ///
    /// Any error from the low-level call is converted to [`TskitError`].
    pub fn kc_distance(&self, other: &TreeSequence, lambda: f64) -> Result<f64, TskitError> {
        self.inner
            .kc_distance(&other.inner, lambda)
            .map_err(|e| e.into())
    }
428
    /// Get the number of samples.
    ///
    /// The value reported by the low-level tree sequence is converted
    /// to a [`SizeType`].
    pub fn num_samples(&self) -> SizeType {
        self.inner.num_samples().into()
    }
433
434    /// Simplify tables and return a new tree sequence.
435    ///
436    /// # Parameters
437    ///
438    /// * `samples`: a slice containing non-null node ids.
439    ///   The tables are simplified with respect to the ancestry
440    ///   of these nodes.
441    /// * `options`: A [`SimplificationOptions`] bit field controlling
442    ///   the behavior of simplification.
443    /// * `idmap`: if `true`, the return value contains a vector equal
444    ///   in length to the input node table.  For each input node,
445    ///   this vector either contains the node's new index or [`NodeId::NULL`]
446    ///   if the input node is not part of the simplified history.
447    pub fn simplify<O: Into<SimplificationOptions>>(
448        &self,
449        samples: &[NodeId],
450        options: O,
451        idmap: bool,
452    ) -> Result<(Self, Option<Vec<NodeId>>), TskitError> {
453        let mut output_node_map: Vec<NodeId> = vec![];
454        if idmap {
455            output_node_map.resize(usize::try_from(self.nodes().num_rows())?, NodeId::NULL);
456        }
457        let llsamples = unsafe {
458            std::slice::from_raw_parts(samples.as_ptr().cast::<tsk_id_t>(), samples.len())
459        };
460        let mut inner = self.inner.simplify(
461            llsamples,
462            options.into().bits(),
463            match idmap {
464                true => output_node_map.as_mut_ptr().cast::<tsk_id_t>(),
465                false => std::ptr::null_mut(),
466            },
467        )?;
468        let views = crate::table_views::TableViews::new_from_tree_sequence(inner.as_mut_ptr())?;
469        Ok((
470            Self { inner, views },
471            match idmap {
472                true => Some(output_node_map),
473                false => None,
474            },
475        ))
476    }
477
478    #[cfg(feature = "provenance")]
479    #[cfg_attr(doc_cfg, doc(cfg(feature = "provenance")))]
480    /// Add provenance record with a time stamp.
481    ///
482    /// All implementation of this trait provided by `tskit` use
483    /// an `ISO 8601` format time stamp
484    /// written using the [RFC 3339](https://tools.ietf.org/html/rfc3339)
485    /// specification.
486    /// This formatting approach has been the most straightforward method
487    /// for supporting round trips to/from a [`crate::provenance::ProvenanceTable`].
488    /// The implementations used here use the [`humantime`](https://docs.rs/humantime/latest/humantime/) crate.
489    ///
490    /// # Parameters
491    ///
492    /// * `record`: the provenance record
493    ///
494    /// # Examples
495    ///
496    /// ```
497    /// let mut tables = tskit::TableCollection::new(1000.).unwrap();
498    /// let mut treeseq = tables.tree_sequence(tskit::TreeSequenceFlags::BUILD_INDEXES).unwrap();
499    /// # #[cfg(feature = "provenance")] {
500    /// treeseq.add_provenance(&String::from("All your provenance r belong 2 us.")).unwrap();
501    ///
502    /// let prov_ref = treeseq.provenances();
503    /// let row_0 = prov_ref.row(0).unwrap();
504    /// assert_eq!(row_0.record, "All your provenance r belong 2 us.");
505    /// let record_0 = prov_ref.record(0).unwrap();
506    /// assert_eq!(record_0, row_0.record);
507    /// let timestamp = prov_ref.timestamp(0).unwrap();
508    /// assert_eq!(timestamp, row_0.timestamp);
509    /// use core::str::FromStr;
510    /// let dt_utc = humantime::Timestamp::from_str(&timestamp).unwrap();
511    /// println!("utc = {}", dt_utc);
512    /// # }
513    /// ```
514    pub fn add_provenance(&mut self, record: &str) -> Result<crate::ProvenanceId, TskitError> {
515        if record.is_empty() {
516            return Err(TskitError::ValueError {
517                got: "empty string".to_string(),
518                expected: "provenance record".to_string(),
519            });
520        }
521        let timestamp = humantime::format_rfc3339(std::time::SystemTime::now()).to_string();
522        let rv = unsafe {
523            ll_bindings::tsk_provenance_table_add_row(
524                &mut (*self.inner.as_ref().tables).provenances,
525                timestamp.as_ptr() as *mut i8,
526                timestamp.len() as ll_bindings::tsk_size_t,
527                record.as_ptr() as *mut i8,
528                record.len() as ll_bindings::tsk_size_t,
529            )
530        };
531        handle_tsk_return_value!(rv, crate::ProvenanceId::from(rv))
532    }
533
534    delegate_table_view_api!();
535
    /// Build a lending iterator over edge differences.
    ///
    /// The iterator is created from this tree sequence with a flags
    /// value of `0`.
    ///
    /// # Errors
    ///
    /// * [`TskitError`] if the `C` back end is unable to allocate
    ///   needed memory
    pub fn edge_differences_iter(
        &self,
    ) -> Result<crate::edge_differences::EdgeDifferencesIterator, TskitError> {
        crate::edge_differences::EdgeDifferencesIterator::new_from_treeseq(self, 0)
    }
547}
548
/// Conversion of a [`TableCollection`] into a [`TreeSequence`].
impl TryFrom<TableCollection> for TreeSequence {
    type Error = TskitError;

    /// Consume `value` and build a tree sequence from it using
    /// [`TreeSequenceFlags::default`]; see [`TreeSequence::new`].
    fn try_from(value: TableCollection) -> Result<Self, Self::Error> {
        Self::new(value, TreeSequenceFlags::default())
    }
}
556
557#[cfg(test)]
558pub(crate) mod test_trees {
559    use super::*;
560    use crate::test_fixtures::{
561        make_small_table_collection, make_small_table_collection_two_trees,
562        treeseq_from_small_table_collection, treeseq_from_small_table_collection_two_trees,
563    };
564    use crate::NodeTraversalOrder;
565    use streaming_iterator::DoubleEndedStreamingIterator;
566    use streaming_iterator::StreamingIterator;
567
568    #[test]
569    fn test_create_treeseq_new_from_tables() {
570        let tables = make_small_table_collection();
571        let treeseq = TreeSequence::new(tables, TreeSequenceFlags::default()).unwrap();
572        let samples = treeseq.sample_nodes();
573        assert_eq!(samples.len(), 2);
574        for i in 1..3 {
575            assert_eq!(samples[i - 1], NodeId::from(i as tsk_id_t));
576        }
577    }
578
    // Smoke test: converting the small fixture tables into a tree sequence
    // via `TableCollection::tree_sequence` must succeed.
    #[test]
    fn test_create_treeseq_from_tables() {
        let tables = make_small_table_collection();
        let _treeseq = tables.tree_sequence(TreeSequenceFlags::default()).unwrap();
    }
584
585    #[test]
586    fn test_iterate_tree_seq_with_one_tree() {
587        let tables = make_small_table_collection();
588        let treeseq = tables.tree_sequence(TreeSequenceFlags::default()).unwrap();
589        let mut ntrees = 0;
590        let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap();
591        while let Some(tree) = tree_iter.next() {
592            ntrees += 1;
593            assert_eq!(tree.current_tree, ntrees);
594            let samples = tree.sample_nodes();
595            assert_eq!(samples.len(), 2);
596            for i in 1..3 {
597                assert_eq!(samples[i - 1], NodeId::from(i as tsk_id_t));
598
599                let mut nsteps = 0;
600                for _ in tree.parents(samples[i - 1]) {
601                    nsteps += 1;
602                }
603                assert_eq!(nsteps, 2);
604            }
605
606            // These nodes are all out of range
607            for i in 100..110 {
608                let mut nsteps = 0;
609                for _ in tree.parents(i) {
610                    nsteps += 1;
611                }
612                assert_eq!(nsteps, 0);
613            }
614
615            assert_eq!(tree.parents(-1_i32).count(), 0);
616            assert_eq!(tree.children(-1_i32).count(), 0);
617
618            let roots = tree.roots_to_vec();
619            for r in roots.iter() {
620                let mut num_children = 0;
621                for _ in tree.children(*r) {
622                    num_children += 1;
623                }
624                assert_eq!(num_children, 2);
625            }
626        }
627        assert_eq!(ntrees, 1);
628    }
629
630    #[test]
631    fn test_iterate_no_roots() {
632        let mut tables = TableCollection::new(100.).unwrap();
633        tables.build_index().unwrap();
634        let treeseq = tables.tree_sequence(TreeSequenceFlags::default()).unwrap();
635        let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap();
636        while let Some(tree) = tree_iter.next() {
637            let mut num_roots = 0;
638            for _ in tree.roots() {
639                num_roots += 1;
640            }
641            assert_eq!(num_roots, 0);
642        }
643    }
644
645    #[test]
646    fn test_samples_iterator_error_when_not_tracking_samples() {
647        let tables = make_small_table_collection();
648        let treeseq = tables.tree_sequence(TreeSequenceFlags::default()).unwrap();
649
650        let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap();
651        if let Some(tree) = tree_iter.next() {
652            for n in tree.traverse_nodes(NodeTraversalOrder::Preorder) {
653                match tree.samples(n) {
654                    Err(_) => (),
655                    _ => panic!("should not be Ok(_) or None"),
656                }
657            }
658        }
659    }
660
    // With default flags the tree tracks its samples: each of the two
    // sample nodes counts one tracked sample and node 0 counts both.
    #[test]
    fn test_num_tracked_samples() {
        let treeseq = treeseq_from_small_table_collection();
        assert_eq!(treeseq.num_samples(), 2);
        let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap();
        if let Some(tree) = tree_iter.next() {
            assert_eq!(tree.num_tracked_samples(2).unwrap(), 1);
            assert_eq!(tree.num_tracked_samples(1).unwrap(), 1);
            assert_eq!(tree.num_tracked_samples(0).unwrap(), 2);
        }
    }
672
    // Expected to panic when the tree is created with `NO_SAMPLE_COUNTS`.
    // NOTE(review): presumably the panic comes from `unwrap()` on an error
    // returned by `num_tracked_samples` — confirm.
    #[should_panic]
    #[test]
    fn test_num_tracked_samples_not_tracking_sample_counts() {
        let treeseq = treeseq_from_small_table_collection();
        assert_eq!(treeseq.num_samples(), 2);
        let mut tree_iter = treeseq.tree_iterator(TreeFlags::NO_SAMPLE_COUNTS).unwrap();
        if let Some(tree) = tree_iter.next() {
            assert_eq!(tree.num_tracked_samples(2).unwrap(), 0);
            assert_eq!(tree.num_tracked_samples(1).unwrap(), 0);
            assert_eq!(tree.num_tracked_samples(0).unwrap(), 0);
        }
    }
685
    // With `SAMPLE_LISTS`, the samples below a node can be iterated and
    // their number must agree with `num_tracked_samples`.
    #[test]
    fn test_iterate_samples() {
        let tables = make_small_table_collection();
        let treeseq = tables.tree_sequence(TreeSequenceFlags::default()).unwrap();

        let mut tree_iter = treeseq.tree_iterator(TreeFlags::SAMPLE_LISTS).unwrap();
        if let Some(tree) = tree_iter.next() {
            assert!(!tree.flags().contains(TreeFlags::NO_SAMPLE_COUNTS));
            assert!(tree.flags().contains(TreeFlags::SAMPLE_LISTS));
            let mut s = vec![];

            // Node 0 is expected to have both samples (1 and 2) below it.
            if let Ok(iter) = tree.samples(0) {
                for i in iter {
                    s.push(i);
                }
            }
            assert_eq!(s.len(), 2);
            assert_eq!(
                s.len(),
                usize::try_from(tree.num_tracked_samples(0).unwrap()).unwrap()
            );
            assert_eq!(s[0], 1);
            assert_eq!(s[1], 2);

            // Each sample node is expected to list only itself.
            for u in 1..3 {
                let mut s = vec![];
                if let Ok(iter) = tree.samples(u) {
                    for i in iter {
                        s.push(i);
                    }
                }
                assert_eq!(s.len(), 1);
                assert_eq!(s[0], u);
                assert_eq!(
                    s.len(),
                    usize::try_from(tree.num_tracked_samples(u).unwrap()).unwrap()
                );
            }
        } else {
            panic!("Expected a tree");
        }
    }
728
729    #[test]
730    fn test_iterate_samples_two_trees() {
731        use super::ll_bindings::tsk_size_t;
732        let treeseq = treeseq_from_small_table_collection_two_trees();
733        assert_eq!(treeseq.num_trees(), 2);
734        let mut tree_iter = treeseq.tree_iterator(TreeFlags::SAMPLE_LISTS).unwrap();
735        let expected_number_of_roots = vec![2, 1];
736        let mut expected_root_ids = vec![
737            vec![NodeId::from(0)],
738            vec![NodeId::from(1), NodeId::from(0)],
739        ];
740        while let Some(tree) = tree_iter.next() {
741            let mut num_roots = 0;
742            let eroot_ids = expected_root_ids.pop().unwrap();
743            for (i, r) in tree.roots().enumerate() {
744                num_roots += 1;
745                assert_eq!(r, eroot_ids[i]);
746            }
747            assert_eq!(
748                expected_number_of_roots[(tree.current_tree - 1) as usize],
749                num_roots
750            );
751            assert_eq!(tree.roots().count(), eroot_ids.len());
752            let mut preoder_nodes = vec![];
753            let mut postoder_nodes = vec![];
754            for n in tree.traverse_nodes(NodeTraversalOrder::Preorder) {
755                let mut nsamples = 0;
756                preoder_nodes.push(n);
757                if let Ok(iter) = tree.samples(n) {
758                    for _ in iter {
759                        nsamples += 1;
760                    }
761                }
762                assert!(nsamples > 0);
763                assert_eq!(nsamples, tree.num_tracked_samples(n).unwrap());
764            }
765            for n in tree.traverse_nodes(NodeTraversalOrder::Postorder) {
766                let mut nsamples = 0;
767                postoder_nodes.push(n);
768                if let Ok(iter) = tree.samples(n) {
769                    for _ in iter {
770                        nsamples += 1;
771                    }
772                }
773                assert!(nsamples > 0);
774                assert_eq!(nsamples, tree.num_tracked_samples(n).unwrap());
775            }
776            assert_eq!(preoder_nodes.len(), postoder_nodes.len());
777
778            // Test our preorder against the tskit functions in 0.99.15
779            {
780                let mut nodes: Vec<NodeId> = vec![
781                    NodeId::NULL;
782                    unsafe { ll_bindings::tsk_tree_get_size_bound(tree.as_ptr()) }
783                        as usize
784                ];
785                let mut num_nodes: tsk_size_t = 0;
786                let ptr = std::ptr::addr_of_mut!(num_nodes);
787                unsafe {
788                    ll_bindings::tsk_tree_preorder(
789                        tree.as_ptr(),
790                        nodes.as_mut_ptr() as *mut tsk_id_t,
791                        ptr,
792                    );
793                }
794                assert_eq!(num_nodes as usize, preoder_nodes.len());
795                for i in 0..num_nodes as usize {
796                    assert_eq!(preoder_nodes[i], nodes[i]);
797                }
798            }
799        }
800    }
801
802    #[test]
803    fn test_kc_distance_naive_test() {
804        let ts1 = treeseq_from_small_table_collection();
805        let ts2 = treeseq_from_small_table_collection();
806
807        let kc = ts1.kc_distance(&ts2, 0.0).unwrap();
808        assert!(kc.is_finite());
809        assert!((kc - 0.).abs() < f64::EPSILON);
810    }
811
    // The tables dumped from a tree sequence must equal a deep copy of
    // the tables it was built from.
    #[test]
    fn test_dump_tables() {
        let tables = make_small_table_collection_two_trees();
        // Have to make b/c tables will no longer exist after making the treeseq
        let tables_copy = tables.deepcopy().unwrap();
        let ts = tables.tree_sequence(TreeSequenceFlags::default()).unwrap();
        let dumped = ts.dump_tables().unwrap();
        assert!(tables_copy.equals(&dumped, crate::TableEqualityOptions::default()));
    }
821
822    #[test]
823    fn test_reverse_tree_iteration() {
824        let treeseq = treeseq_from_small_table_collection_two_trees();
825        let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap();
826        let mut starts_fwd = vec![];
827        let mut stops_fwd = vec![];
828        let mut starts_rev = vec![];
829        let mut stops_rev = vec![];
830        while let Some(tree) = tree_iter.next() {
831            let interval = tree.interval();
832            starts_fwd.push(interval.0);
833            stops_fwd.push(interval.1);
834        }
835        assert_eq!(stops_fwd.len(), 2);
836        assert_eq!(stops_fwd.len(), 2);
837
838        // NOTE: we do NOT need to create a new iterator.
839        while let Some(tree) = tree_iter.next_back() {
840            let interval = tree.interval();
841            starts_rev.push(interval.0);
842            stops_rev.push(interval.1);
843        }
844        assert_eq!(starts_fwd.len(), starts_rev.len());
845        assert_eq!(stops_fwd.len(), stops_rev.len());
846
847        starts_rev.reverse();
848        assert!(starts_fwd == starts_rev);
849        stops_rev.reverse();
850        assert!(stops_fwd == stops_rev);
851    }
852
853    // FIXME: remove later
854    #[test]
855    fn test_array_lifetime() {
856        let treeseq = treeseq_from_small_table_collection_two_trees();
857        let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap();
858        if let Some(tree) = tree_iter.next() {
859            let pa = tree.parent_array();
860            let mut pc = vec![];
861            for i in pa.iter() {
862                pc.push(*i);
863            }
864            for (i, p) in pc.iter().enumerate() {
865                assert_eq!(pa[i], *p);
866            }
867        } else {
868            panic!("Expected a tree.");
869        }
870    }
871}
872
#[cfg(test)]
mod test_treeeseq_send_sync {
    use crate::test_fixtures::treeseq_from_small_table_collection_two_trees;
    use std::sync::Arc;
    use std::thread;

    // A tree sequence wrapped in an `Arc` can be moved into another
    // thread and queried there.
    #[test]
    fn build_arc() {
        let shared = Arc::new(treeseq_from_small_table_collection_two_trees());
        let worker = thread::spawn(move || shared.num_trees());
        let ntrees = worker.join().unwrap();
        assert_eq!(ntrees, 2);
    }
}
887}