winsfs_core/
lib.rs

//! Methods for inferring the site frequency spectrum (SFS) from low-quality data
//! using various forms of the expectation-maximisation (EM) algorithm.
3
4#![warn(missing_docs)]
5use std::mem::MaybeUninit;
6
7pub mod em;
8pub mod io;
9pub mod saf;
10pub mod sfs;
11
12pub mod prelude {
13    //! The types required for common usage.
14    pub use crate::em::{
15        stopping::{LogLikelihoodTolerance, Steps},
16        Em, ParallelStandardEm, StandardEm, StreamingEm, WindowEm,
17    };
18    pub use crate::saf::Saf;
19    pub use crate::sfs::{Sfs, USfs};
20}
21
22/// Sets the number of threads to use for parallelization.
23///
24/// This is a thin wrapper around [`rayon::ThreadPoolBuilder`] to save users from having to
25/// import `rayon` to control parallelism. The meaning of the `threads` parameter here derives
26/// from [`rayon::ThreadPoolBuilder::num_threads`], see it's documentation for details.
27pub fn set_threads(threads: usize) -> Result<(), rayon::ThreadPoolBuildError> {
28    rayon::ThreadPoolBuilder::new()
29        .num_threads(threads)
30        .build_global()
31}
32
/// This is an internal implementation detail.
///
/// `matrix![[a, b], [c, d]]` expands to a tuple `(cols, data)`: `cols` is a `Vec` holding the
/// number of elements in each bracketed row, and `data` is a `Vec` of all elements flattened
/// row by row. Trailing commas are accepted both inside rows and after the last row.
///
/// # Panics
///
/// Panics if the rows do not all contain the same number of elements.
#[doc(hidden)]
#[macro_export]
macro_rules! matrix {
    ($([$($x:literal),+ $(,)?]),+ $(,)?) => {{
        // Number of elements in each row; every row must have the same length.
        let cols = vec![$($crate::matrix!(count: $($x),+)),+];
        assert!(
            cols.windows(2).all(|w| w[0] == w[1]),
            "all matrix rows must have the same number of columns, found row lengths {:?}",
            cols
        );
        // All elements, flattened in the order the rows were written.
        let vec = vec![$($($x),+),+];
        (cols, vec)
    }};
    // Counts its arguments by building an array with one `()` per argument and taking its length.
    (count: $($x:expr),+) => {
        <[()]>::len(&[$($crate::matrix!(replace: $x)),*])
    };
    // Maps any expression to `()`; helper for the `count` arm.
    (replace: $x:expr) => {()};
}
48
/// Extension methods on fixed-size arrays `[T; N]`.
pub(crate) trait ArrayExt<const N: usize, T> {
    /// Returns an array whose `i`-th entry is a shared reference to `self[i]`.
    // TODO: Replace with `<[T; N]>::each_ref` (stable since Rust 1.77) once the MSRV allows,
    // see github.com/rust-lang/rust/issues/76118
    fn by_ref(&self) -> [&T; N];

    /// Returns an array whose `i`-th entry is a mutable reference to `self[i]`.
    // TODO: Replace with `<[T; N]>::each_mut` (stable since Rust 1.77) once the MSRV allows,
    // see github.com/rust-lang/rust/issues/76118
    fn by_mut(&mut self) -> [&mut T; N];

    /// Consumes `self` and `rhs`, pairing them element-wise: entry `i` is `(self[i], rhs[i])`.
    // TODO: Use zip when stable,
    // see github.com/rust-lang/rust/issues/80094
    fn array_zip<U>(self, rhs: [U; N]) -> [(T, U); N];
}

impl<const N: usize, T> ArrayExt<N, T> for [T; N] {
    fn by_ref(&self) -> [&T; N] {
        // `array::map` applies the closure to each slot in order, so pulling the next item
        // from the iterator fills slot `i` with a reference to `self[i]`. The iterator has
        // exactly `N` items, so `expect` never fires.
        let mut elements = self.iter();
        [(); N].map(|()| elements.next().expect("array iterator yields exactly N items"))
    }

    fn by_mut(&mut self) -> [&mut T; N] {
        // Same in-order fill as `by_ref`, but with disjoint mutable references.
        let mut elements = self.iter_mut();
        [(); N].map(|()| elements.next().expect("array iterator yields exactly N items"))
    }

    fn array_zip<U>(self, rhs: [U; N]) -> [(T, U); N] {
        // `IntoIterator::into_iter` moves elements out by value in every edition (method-call
        // `rhs.into_iter()` on an array yields references before edition 2021).
        let mut rhs = IntoIterator::into_iter(rhs);
        // `map` walks `self` in order, so each left element meets the matching right element.
        self.map(|lhs| (lhs, rhs.next().expect("array iterator yields exactly N items")))
    }
}
105
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_by_ref() {
        assert_eq!([1, 2, 3].by_ref(), [&1, &2, &3]);
    }

    #[test]
    fn test_by_ref_empty() {
        let empty: [i32; 0] = [];
        assert_eq!(empty.by_ref().len(), 0);
    }

    #[test]
    fn test_by_mut() {
        assert_eq!([1, 2, 3].by_mut(), [&mut 1, &mut 2, &mut 3]);
    }

    #[test]
    fn test_by_mut_writes_through() {
        // Mutations through the returned references must land in the original array.
        let mut xs = [1, 2, 3];
        let [a, b, c] = xs.by_mut();
        *a += 10;
        *b += 20;
        *c += 30;
        assert_eq!(xs, [11, 22, 33]);
    }

    #[test]
    fn test_zip() {
        assert_eq!(
            [1, 2, 3].array_zip([0.1, 0.2, 0.3]),
            [(1, 0.1), (2, 0.2), (3, 0.3)],
        )
    }

    #[test]
    fn test_zip_moves_non_copy_values() {
        // Owned, heap-allocated elements must be moved (not copied or dropped) into the pairs.
        let zipped = [String::from("a"), String::from("b")].array_zip([1u8, 2]);
        assert_eq!(zipped, [(String::from("a"), 1), (String::from("b"), 2)]);
    }

    #[test]
    fn test_zip_empty() {
        let lhs: [u8; 0] = [];
        let zipped: [(u8, char); 0] = lhs.array_zip([]);
        assert_eq!(zipped.len(), 0);
    }
}