winsfs_core/
sfs.rs

1//! Multi-dimensional site frequency spectra ("SFS").
2//!
3//! The central type is the [`SfsBase`] struct, which represents an SFS with a dimensionality
4//! that may or may not be known at compile time, and which may or may not be normalised to
5//! probability scale. Type aliases [`Sfs`], [`USfs`], [`DynSfs`], and [`DynUSfs`] are exposed
6//! for convenience.
7
8use std::{
9    cmp::Ordering,
10    error::Error,
11    fmt::{self, Write as _},
12    marker::PhantomData,
13    ops::{Add, AddAssign, Index, IndexMut, Sub, SubAssign},
14    slice,
15};
16
17use crate::ArrayExt;
18
19pub mod generics;
20use generics::{ConstShape, DynShape, Norm, Normalisation, Shape, Unnorm};
21
22pub mod io;
23
24pub mod iter;
25use iter::Indices;
26
27mod em;
28
/// Largest absolute deviation from one that the sum of the SFS values may have for the SFS
/// to still be accepted as normalised by `SfsBase::into_normalised`.
const NORMALISATION_TOLERANCE: f64 = 10.0 * f64::EPSILON;
30
31/// Creates an unnormalised 1D SFS.
32///
33/// This is mainly intended for readability in doc-tests, but may also be useful elsewhere.
34///
35/// # Examples
36///
37/// Create SFS by repeating an element:
38///
39/// ```
40/// use winsfs_core::sfs1d;
41/// let sfs = sfs1d![0.1; 10];
42/// assert!(sfs.iter().all(|&x| x == 0.1));
43/// ```
44///
45/// Create SFS from a list of elements:
46///
47/// ```
48/// use winsfs_core::sfs1d;
49/// let sfs = sfs1d![0.1, 0.2, 0.3];
50/// assert_eq!(sfs[[0]], 0.1);
51/// assert_eq!(sfs[[1]], 0.2);
52/// assert_eq!(sfs[[2]], 0.3);
53/// ```
#[macro_export]
macro_rules! sfs1d {
    // `sfs1d![value; len]`: a 1D SFS of `len` entries, all equal to `value`.
    ($value:expr; $len:expr) => {
        $crate::sfs::USfs::from_elem($value, [$len])
    };
    // `sfs1d![a, b, c]`: a 1D SFS holding the listed values; a trailing comma is accepted.
    ($($value:expr),+ $(,)?) => {
        $crate::sfs::USfs::from_vec(vec![$($value),+])
    };
}
63
64/// Creates an unnormalised 2D SFS.
65///
66/// This is mainly intended for readability in doc-tests, but may also be useful elsewhere.
67///
68/// # Examples
69///
70/// ```
71/// use winsfs_core::sfs2d;
72/// let sfs = sfs2d![
73///     [0.1, 0.2, 0.3],
74///     [0.4, 0.5, 0.6],
75///     [0.7, 0.8, 0.9],
76/// ];
77/// assert_eq!(sfs[[0, 0]], 0.1);
78/// assert_eq!(sfs[[1, 0]], 0.4);
79/// assert_eq!(sfs[[2, 0]], 0.7);
80/// ```
#[macro_export]
macro_rules! sfs2d {
    ($([$($value:literal),+ $(,)?]),+ $(,)?) => {{
        // `matrix!` hands back one entry per row (presumably the row lengths) together with
        // the flattened values; the shape is (number of rows, first entry).
        let (row_lens, flat) = $crate::matrix!($([$($value),+]),+);
        $crate::sfs::SfsBase::from_vec_shape(flat, [row_lens.len(), row_lens[0]]).unwrap()
    }};
}
89
/// A multi-dimensional site frequency spectrum ("SFS").
///
/// Elements are stored in row-major order: the last index varies the fastest.
///
/// The number of dimensions of the SFS may either be known at compile-time or run-time,
/// and this is governed by the [`Shape`] trait. Moreover, the SFS may or may not be normalised
/// to probability scale, and this is controlled by the [`Normalisation`] trait.
/// See also the [`Sfs`], [`USfs`], [`DynSfs`], and [`DynUSfs`] type aliases.
#[derive(Clone, Debug, PartialEq)]
// TODO: Replace normalisation with const enum once these are permitted in const generics,
// see github.com/rust-lang/rust/issues/95174
pub struct SfsBase<S: Shape, N: Normalisation> {
    /// The SFS values as a flat vector in row-major order.
    values: Vec<f64>,
    /// The number of entries along each dimension.
    pub(crate) shape: S,
    /// The strides belonging to `shape`, as returned by `Shape::strides` on construction.
    pub(crate) strides: S,
    /// Zero-sized marker that tracks the normalisation state `N` at the type level.
    norm: PhantomData<N>,
}
107
/// A normalised SFS with shape known at compile-time.
///
/// `D` is the number of dimensions of the SFS.
pub type Sfs<const D: usize> = SfsBase<ConstShape<D>, Norm>;

/// An unnormalised SFS with shape known at compile-time.
///
/// `D` is the number of dimensions of the SFS.
pub type USfs<const D: usize> = SfsBase<ConstShape<D>, Unnorm>;

/// A normalised SFS with shape known at run-time.
pub type DynSfs = SfsBase<DynShape, Norm>;

/// An unnormalised SFS with shape known at run-time.
pub type DynUSfs = SfsBase<DynShape, Unnorm>;
119
120impl<S: Shape, N: Normalisation> SfsBase<S, N> {
121    /// Returns the values of the SFS as a flat, row-major slice.
122    ///
123    /// # Examples
124    ///
125    /// ```
126    /// use winsfs_core::sfs2d;
127    /// let sfs = sfs2d![
128    ///     [0., 1., 2.],
129    ///     [3., 4., 5.],
130    /// ];
131    /// assert_eq!(sfs.as_slice(), [0., 1., 2., 3., 4., 5.]);
132    /// ```
133    #[inline]
134    pub fn as_slice(&self) -> &[f64] {
135        &self.values
136    }
137
138    /// Returns a folded version of the SFS.
139    ///
140    /// Folding is useful when the spectrum has not been properly polarised, so that there is
141    /// no meaningful distinction between having 0 and 2N (in the diploid case) variants at a site.
142    /// The folding operation collapses these indistinguishable bins by adding the value from the
143    /// lower part of the spectrum onto the upper, and setting the lower value to zero.
144    ///
145    /// Note that we adopt the convention that on the "diagonal" of the SFS, where there is less of
146    /// a convention on what is the correct way of folding, the arithmetic mean of the candidates is
147    /// used. The examples below illustrate this.
148    ///
149    /// # Examples
150    ///
151    /// Folding in 1D:
152    ///
153    /// ```
154    /// use winsfs_core::sfs1d;
155    /// let sfs = sfs1d![5., 2., 3., 10., 1.];
156    /// assert_eq!(sfs.fold(), sfs1d![6., 12., 3., 0., 0.]);
157    /// ```
158    ///
159    /// Folding in 2D (square input):
160    ///
161    /// ```
162    /// use winsfs_core::sfs2d;
163    /// let sfs = sfs2d![
164    ///     [4., 2., 10.],
165    ///     [0., 3., 4.],
166    ///     [7., 2., 1.],
167    /// ];
168    /// let expected = sfs2d![
169    ///     [5., 4., 8.5],
170    ///     [4., 3., 0.],
171    ///     [8.5, 0., 0.],
172    /// ];
173    /// assert_eq!(sfs.fold(), expected);
174    /// ```
175    ///
176    /// Folding in 2D (non-square input):
177    ///
178    /// ```
179    /// use winsfs_core::sfs2d;
180    /// let sfs = sfs2d![
181    ///     [4., 2., 10.],
182    ///     [0., 3., 4.],
183    /// ];
184    /// let expected = sfs2d![
185    ///     [8., 5., 0.],
186    ///     [10., 0., 0.],
187    /// ];
188    /// assert_eq!(sfs.fold(), expected);
189    /// ```    
190    pub fn fold(&self) -> Self {
191        let n = self.values.len();
192        let total_count = self.shape.iter().sum::<usize>() - self.shape.len();
193
194        // In general, this point divides the folding line. Since we are folding onto the "upper"
195        // part of the array, we want to fold anything "below" it onto something "above" it.
196        let mid_count = total_count / 2;
197
198        // The spectrum may or may not have a "diagonal", i.e. a hyperplane that falls exactly on
199        // the midpoint. If such a diagonal exists, we need to handle it as a special case when
200        // folding below.
201        //
202        // For example, in 1D a spectrum with five elements has a "diagonal", marked X:
203        // [-, -, X, -, -]
204        // Whereas on with four elements would not.
205        //
206        // In two dimensions, e.g. three-by-three elements has a diagonal:
207        // [-, -, X]
208        // [-, X, -]
209        // [X, -, -]
210        // whereas two-by-three would not. On the other hand, two-by-four has a diagonal:
211        // [-, -, X, -]
212        // [-, X, -, -]
213        //
214        // Note that even-ploidy data should always have a diagonal, whereas odd-ploidy data
215        // may or may not.
216        let has_diagonal = total_count % 2 == 0;
217
218        // Note that we cannot use the algorithm below in-place, since the reverse iterator
219        // may reach elements that have already been folded, which causes bugs. Hence we fold
220        // into a zero-initialised copy.
221        let mut folded = Self::new_unchecked(vec![0.0; n], self.shape.clone());
222
223        // We iterate over indices rather than values since we have to mutate on the array
224        // while looking at it from both directions.
225        (0..n).zip((0..n).rev()).for_each(|(i, rev_i)| {
226            let count = compute_index_sum_unchecked(i, n, self.shape.as_ref());
227
228            match (count.cmp(&mid_count), has_diagonal) {
229                (Ordering::Less, _) | (Ordering::Equal, false) => {
230                    // We are in the upper part of the spectrum that should be folded onto.
231                    folded.values[i] = self.values[i] + self.values[rev_i];
232                }
233                (Ordering::Equal, true) => {
234                    // We are on a diagonal, which must be handled as a special case:
235                    // there are apparently different opinions on what the most correct
236                    // thing to do is. This adopts the same strategy as e.g. in dadi.
237                    folded.values[i] = 0.5 * self.values[i] + 0.5 * self.values[rev_i];
238                }
239                (Ordering::Greater, _) => (),
240            }
241        });
242
243        folded
244    }
245
    /// Returns a string containing a flat, row-major representation of the SFS.
247    ///
248    /// # Examples
249    ///
250    /// ```
251    /// use winsfs_core::sfs1d;
252    /// let sfs = sfs1d![0.0, 0.1, 0.2];
253    /// assert_eq!(sfs.format_flat(" ", 1), "0.0 0.1 0.2");
254    /// ```
255    ///
256    /// ```
257    /// use winsfs_core::sfs2d;
258    /// let  sfs = sfs2d![[0.01, 0.12], [0.23, 0.34]];
259    /// assert_eq!(sfs.format_flat(",", 2), "0.01,0.12,0.23,0.34");
260    /// ```
261    pub fn format_flat(&self, sep: &str, precision: usize) -> String {
262        if let Some(first) = self.values.first() {
263            let cap = self.values.len() * (precision + 3);
264            let mut init = String::with_capacity(cap);
265            write!(init, "{first:.precision$}").unwrap();
266            // init.push_str(&format!("{:.precision$}", first));
267
268            self.iter().skip(1).fold(init, |mut s, x| {
269                s.push_str(sep);
270                write!(s, "{x:.precision$}").unwrap();
271                s
272            })
273        } else {
274            String::new()
275        }
276    }
277
278    /// Returns a value at an index in the SFS.
279    ///
280    /// If the index is out of bounds, `None` is returned.
281    ///
282    /// # Examples
283    ///
284    /// ```
285    /// use winsfs_core::sfs1d;
286    /// let sfs = sfs1d![0.0, 0.1, 0.2];
287    /// assert_eq!(sfs.get(&[0]), Some(&0.0));
288    /// assert_eq!(sfs.get(&[1]), Some(&0.1));
289    /// assert_eq!(sfs.get(&[2]), Some(&0.2));
290    /// assert_eq!(sfs.get(&[3]), None);
291    /// ```
292    ///
293    /// ```
294    /// use winsfs_core::sfs2d;
295    /// let sfs = sfs2d![[0.0, 0.1, 0.2], [0.3, 0.4, 0.5], [0.6, 0.7, 0.8]];
296    /// assert_eq!(sfs.get(&[0, 0]), Some(&0.0));
297    /// assert_eq!(sfs.get(&[1, 2]), Some(&0.5));
298    /// assert_eq!(sfs.get(&[3, 0]), None);
299    /// ```
300    #[inline]
301    pub fn get(&self, index: &S) -> Option<&f64> {
302        self.values.get(compute_flat(index, &self.shape)?)
303    }
304
305    /// Returns a normalised SFS, consuming `self`.
306    ///
307    /// This works purely on the type level, and does not modify the actual values in the SFS.
308    /// If the SFS is not already normalised, an error is returned. To modify the SFS to become
309    /// normalised, see [`Sfs::normalise`].
310    ///
311    /// # Examples
312    ///
313    /// An unnormalised SFS with values summing to one can be turned into a normalised SFS:
314    ///
315    /// ```
316    /// use winsfs_core::{sfs1d, sfs::{Sfs, USfs}};
317    /// let sfs: USfs<1> = sfs1d![0.2; 5];
318    /// let sfs: Sfs<1> = sfs.into_normalised().unwrap();
319    /// ```
320    ///
    /// Otherwise, an unnormalised SFS cannot be turned into a normalised SFS using this method:
322    ///
323    /// ```
324    /// use winsfs_core::{sfs1d, sfs::USfs};
325    /// let sfs: USfs<1> = sfs1d![2.; 5];
326    /// assert!(sfs.into_normalised().is_err());
327    /// ```
328    ///
329    /// Use [`Sfs::normalise`] instead.
330    #[inline]
331    pub fn into_normalised(self) -> Result<SfsBase<S, Norm>, NormError> {
332        let sum = self.sum();
333
334        if (sum - 1.).abs() <= NORMALISATION_TOLERANCE {
335            Ok(self.into_normalised_unchecked())
336        } else {
337            Err(NormError { sum })
338        }
339    }
340
341    #[inline]
342    fn into_normalised_unchecked(self) -> SfsBase<S, Norm> {
343        SfsBase {
344            values: self.values,
345            shape: self.shape,
346            strides: self.strides,
347            norm: PhantomData,
348        }
349    }
350
351    /// Returns an unnormalised SFS, consuming `self`.
352    ///
353    /// This works purely on the type level, and does not modify the actual values in the SFS.
354    ///
355    /// # Examples
356    ///
357    /// ```
358    /// use winsfs_core::sfs::{Sfs, USfs};
359    /// let sfs: Sfs<1> = Sfs::uniform([7]);
360    /// let sfs: USfs<1> = sfs.into_unnormalised();
361    /// ```
362    #[inline]
363    pub fn into_unnormalised(self) -> SfsBase<S, Unnorm> {
364        SfsBase {
365            values: self.values,
366            shape: self.shape,
367            strides: self.strides,
368            norm: PhantomData,
369        }
370    }
371
372    /// Returns an iterator over the elements in the SFS in row-major order.
373    ///
374    /// # Examples
375    ///
376    /// ```
377    /// use winsfs_core::sfs2d;
378    /// let sfs = sfs2d![
379    ///     [0., 1., 2.],
380    ///     [3., 4., 5.],
381    ///     [6., 7., 8.],
382    /// ];
383    /// let expected = (0..9).map(|x| x as f64);
384    /// assert!(sfs.iter().zip(expected).all(|(&x, y)| x == y));
385    /// ```
386    #[inline]
387    pub fn iter(&self) -> slice::Iter<'_, f64> {
388        self.values.iter()
389    }
390
391    /// Creates a new SFS.
392    #[inline]
393    fn new_unchecked(values: Vec<f64>, shape: S) -> Self {
394        let strides = shape.strides();
395
396        Self {
397            values,
398            shape,
399            strides,
400            norm: PhantomData,
401        }
402    }
403
404    /// Returns an unnormalised SFS scaled by some constant, consuming `self`.
405    ///
406    /// # Examples
407    ///
408    /// ```
409    /// use winsfs_core::sfs1d;
410    /// assert_eq!(
411    ///     sfs1d![0., 1.,  2.,  3.,  4.].scale(10.),
412    ///     sfs1d![0., 10., 20., 30., 40.],
413    /// );
414    /// ```
415    #[inline]
416    #[must_use = "returns scaled SFS, doesn't modify in-place"]
417    pub fn scale(mut self, scale: f64) -> SfsBase<S, Unnorm> {
418        self.values.iter_mut().for_each(|x| *x *= scale);
419
420        self.into_unnormalised()
421    }
422
423    /// Returns the SFS shape.
424    ///
425    /// # Examples
426    ///
427    /// ```
428    /// use winsfs_core::sfs2d;
429    /// let sfs = sfs2d![
430    ///     [0., 1., 2.],
431    ///     [3., 4., 5.],
432    /// ];
433    /// assert_eq!(sfs.shape(), &[2, 3]);
434    /// ```
435    pub fn shape(&self) -> &S {
436        &self.shape
437    }
438
439    /// Returns the sum of values in the SFS.
440    #[inline]
441    fn sum(&self) -> f64 {
442        self.iter().sum()
443    }
444}
445
446impl<const D: usize, N: Normalisation> SfsBase<ConstShape<D>, N> {
447    /// Returns an iterator over the sample frequencies of the SFS in row-major order.
448    ///
449    /// Note that this is *not* the contents of SFS, but the frequencies corresponding
450    /// to the indices. See [`Sfs::iter`] for an iterator over the SFS values themselves.
451    ///
452    /// # Examples
453    ///
454    /// ```
455    /// use winsfs_core::sfs::Sfs;
456    /// let sfs = Sfs::uniform([2, 3]);
457    /// let mut iter = sfs.frequencies();
458    /// assert_eq!(iter.next(), Some([0., 0.]));
459    /// assert_eq!(iter.next(), Some([0., 0.5]));
460    /// assert_eq!(iter.next(), Some([0., 1.]));
461    /// assert_eq!(iter.next(), Some([1., 0.]));
462    /// assert_eq!(iter.next(), Some([1., 0.5]));
463    /// assert_eq!(iter.next(), Some([1., 1.]));
464    /// assert!(iter.next().is_none());
465    /// ```
466    pub fn frequencies(&self) -> impl Iterator<Item = [f64; D]> {
467        let n_arr = self.shape.map(|n| n - 1);
468        self.indices()
469            .map(move |idx_arr| idx_arr.array_zip(n_arr).map(|(i, n)| i as f64 / n as f64))
470    }
471
472    /// Returns an iterator over the indices in the SFS in row-major order.
473    ///
474    /// # Examples
475    ///
476    /// ```
477    /// use winsfs_core::sfs::Sfs;
478    /// let sfs = Sfs::uniform([2, 3]);
479    /// let mut iter = sfs.indices();
480    /// assert_eq!(iter.next(), Some([0, 0]));
481    /// assert_eq!(iter.next(), Some([0, 1]));
482    /// assert_eq!(iter.next(), Some([0, 2]));
483    /// assert_eq!(iter.next(), Some([1, 0]));
484    /// assert_eq!(iter.next(), Some([1, 1]));
485    /// assert_eq!(iter.next(), Some([1, 2]));
486    /// assert!(iter.next().is_none());
487    /// ```
488    pub fn indices(&self) -> Indices<ConstShape<D>> {
489        Indices::from_shape(self.shape)
490    }
491}
492
493impl<S: Shape> SfsBase<S, Norm> {
494    /// Creates a new, normalised, and uniform SFS.
495    ///
496    /// # Examples
497    ///
498    /// ```
499    /// use winsfs_core::sfs::Sfs;
500    /// let sfs = Sfs::uniform([2, 5]);
501    /// assert!(sfs.iter().all(|&x| x == 0.1));
502    /// ```
503    pub fn uniform(shape: S) -> SfsBase<S, Norm> {
504        let n: usize = shape.iter().product();
505
506        let elem = 1.0 / n as f64;
507
508        SfsBase::new_unchecked(vec![elem; n], shape)
509    }
510}
511
512impl<S: Shape> SfsBase<S, Unnorm> {
    /// Returns the values of the SFS as a mutable, flat, row-major slice.
514    ///
515    /// # Examples
516    ///
517    /// ```
518    /// use winsfs_core::sfs2d;
519    /// let mut sfs = sfs2d![
520    ///     [0., 1., 2.],
521    ///     [3., 4., 5.],
522    /// ];
523    /// assert_eq!(sfs.as_slice(), [0., 1., 2., 3., 4., 5.]);
524    /// sfs.as_mut_slice()[0] = 100.;
525    /// assert_eq!(sfs.as_slice(), [100., 1., 2., 3., 4., 5.]);
526    /// ```
527    #[inline]
528    pub fn as_mut_slice(&mut self) -> &mut [f64] {
529        &mut self.values
530    }
531
532    /// Creates a new, unnormalised SFS by repeating a single value.
533    ///
534    /// See also [`Sfs::uniform`] to create a normalised SFS with uniform values.
535    ///
536    /// # Examples
537    ///
538    /// ```
539    /// use winsfs_core::sfs::USfs;
540    /// let sfs = USfs::from_elem(0.1, [7, 5]);
541    /// assert_eq!(sfs.shape(), &[7, 5]);
542    /// assert!(sfs.iter().all(|&x| x == 0.1));
543    /// ```
544    pub fn from_elem(elem: f64, shape: S) -> Self {
545        let n = shape.iter().product();
546
547        Self::new_unchecked(vec![elem; n], shape)
548    }
549
550    /// Creates a new, unnormalised SFS from an iterator.
551    ///
552    /// # Examples
553    ///
554    /// ```
555    /// use winsfs_core::sfs::USfs;
556    /// let iter = (0..9).map(|x| x as f64);
557    /// let sfs = USfs::from_iter_shape(iter, [3, 3]).expect("shape didn't fit iterator!");
558    /// assert_eq!(sfs[[1, 2]], 5.0);
559    /// ```
560    pub fn from_iter_shape<I>(iter: I, shape: S) -> Result<Self, ShapeError<S>>
561    where
562        I: IntoIterator<Item = f64>,
563    {
564        Self::from_vec_shape(iter.into_iter().collect(), shape)
565    }
566
567    /// Creates a new, unnormalised SFS from a vector.
568    ///
569    /// # Examples
570    ///
571    /// ```
572    /// use winsfs_core::sfs::USfs;
573    /// let vec: Vec<f64> = (0..9).map(|x| x as f64).collect();
574    /// let sfs = USfs::from_vec_shape(vec, [3, 3]).expect("shape didn't fit vector!");
575    /// assert_eq!(sfs[[2, 0]], 6.0);
576    /// ```
577    pub fn from_vec_shape(vec: Vec<f64>, shape: S) -> Result<Self, ShapeError<S>> {
578        let n: usize = shape.iter().product();
579
580        match vec.len() == n {
581            true => Ok(Self::new_unchecked(vec, shape)),
582            false => Err(ShapeError::new(n, shape)),
583        }
584    }
585
586    /// Returns a mutable reference to a value at an index in the SFS.
587    ///
588    /// If the index is out of bounds, `None` is returned.
589    ///
590    /// # Examples
591    ///
592    /// ```
593    /// use winsfs_core::sfs1d;
594    /// let mut sfs = sfs1d![0.0, 0.1, 0.2];
595    /// assert_eq!(sfs[[0]], 0.0);
596    /// if let Some(v) = sfs.get_mut(&[0]) {
597    ///     *v = 0.5;
598    /// }
599    /// assert_eq!(sfs[[0]], 0.5);
600    /// ```
601    ///
602    /// ```
603    /// use winsfs_core::sfs2d;
604    /// let mut sfs = sfs2d![[0.0, 0.1, 0.2], [0.3, 0.4, 0.5], [0.6, 0.7, 0.8]];
605    /// assert_eq!(sfs[[0, 0]], 0.0);
606    /// if let Some(v) = sfs.get_mut(&[0, 0]) {
607    ///     *v = 0.5;
608    /// }
609    /// assert_eq!(sfs[[0, 0]], 0.5);
610    /// ```
611    #[inline]
612    pub fn get_mut(&mut self, index: &S) -> Option<&mut f64> {
613        self.values.get_mut(compute_flat(index, &self.shape)?)
614    }
615
616    /// Returns an iterator over mutable references to the elements in the SFS in row-major order.
617    #[inline]
618    pub fn iter_mut(&mut self) -> slice::IterMut<'_, f64> {
619        self.values.iter_mut()
620    }
621
622    /// Returns a normalised SFS, consuming `self`.
623    ///
624    /// The values in the SFS are modified to sum to one.
625    ///
626    /// # Examples
627    ///
628    /// ```
629    /// use winsfs_core::{sfs1d, sfs::{Sfs, USfs}};
630    /// let sfs: USfs<1> = sfs1d![0., 1., 2., 3., 4.];
631    /// let sfs: Sfs<1> = sfs.normalise();
632    /// assert_eq!(sfs[[1]], 0.1);
633    /// ```
634    #[inline]
635    #[must_use = "returns normalised SFS, doesn't modify in-place"]
636    pub fn normalise(mut self) -> SfsBase<S, Norm> {
637        let sum = self.sum();
638
639        self.iter_mut().for_each(|x| *x /= sum);
640
641        self.into_normalised_unchecked()
642    }
643
    /// Creates a new, unnormalised SFS with all entries set to zero.
645    ///
646    /// # Examples
647    ///
648    /// ```
649    /// use winsfs_core::sfs::USfs;
650    /// let sfs = USfs::zeros([2, 5]);
651    /// assert!(sfs.iter().all(|&x| x == 0.0));
652    /// ```
653    pub fn zeros(shape: S) -> Self {
654        Self::from_elem(0.0, shape)
655    }
656}
657
658impl SfsBase<ConstShape<1>, Unnorm> {
659    /// Creates a new SFS from a vector.
660    ///
661    /// # Examples
662    ///
663    /// ```
664    /// use winsfs_core::sfs::USfs;
665    /// let sfs = USfs::from_vec(vec![0., 1., 2.]);
666    /// assert_eq!(sfs.shape(), &[3]);
667    /// assert_eq!(sfs[[1]], 1.);
668    /// ```
669    pub fn from_vec(values: Vec<f64>) -> Self {
670        let shape = [values.len()];
671
672        Self::new_unchecked(values, shape)
673    }
674}
675
676impl SfsBase<ConstShape<2>, Norm> {
677    /// Returns the f2-statistic.
678    ///
679    /// # Examples
680    ///
681    /// ```
682    /// use winsfs_core::sfs2d;
683    /// let sfs = sfs2d![
684    ///     [1., 0., 0.],
685    ///     [0., 1., 0.],
686    ///     [0., 0., 1.],
687    /// ].normalise();
688    /// assert_eq!(sfs.f2(), 0.);
689    /// ```
690    pub fn f2(&self) -> f64 {
691        self.iter()
692            .zip(self.frequencies())
693            .map(|(v, [f_i, f_j])| v * (f_i - f_j).powi(2))
694            .sum()
695    }
696}
697
698macro_rules! impl_op {
699    ($trait:ident, $method:ident, $assign_trait:ident, $assign_method:ident) => {
700        impl<S: Shape, N: Normalisation> $assign_trait<&SfsBase<S, N>> for SfsBase<S, Unnorm> {
701            #[inline]
702            fn $assign_method(&mut self, rhs: &SfsBase<S, N>) {
703                assert_eq!(self.shape, rhs.shape);
704
705                self.iter_mut()
706                    .zip(rhs.iter())
707                    .for_each(|(x, rhs)| x.$assign_method(rhs));
708            }
709        }
710
711        impl<S: Shape, N: Normalisation> $assign_trait<SfsBase<S, N>> for SfsBase<S, Unnorm> {
712            #[inline]
713            fn $assign_method(&mut self, rhs: SfsBase<S, N>) {
714                self.$assign_method(&rhs);
715            }
716        }
717
718        impl<S: Shape, N: Normalisation, M: Normalisation> $trait<SfsBase<S, M>> for SfsBase<S, N> {
719            type Output = SfsBase<S, Unnorm>;
720
721            #[inline]
722            fn $method(self, rhs: SfsBase<S, M>) -> Self::Output {
723                let mut sfs = self.into_unnormalised();
724                sfs.$assign_method(&rhs);
725                sfs
726            }
727        }
728
729        impl<S: Shape, N: Normalisation, M: Normalisation> $trait<&SfsBase<S, M>>
730            for SfsBase<S, N>
731        {
732            type Output = SfsBase<S, Unnorm>;
733
734            #[inline]
735            fn $method(self, rhs: &SfsBase<S, M>) -> Self::Output {
736                let mut sfs = self.into_unnormalised();
737                sfs.$assign_method(rhs);
738                sfs
739            }
740        }
741    };
742}
743impl_op!(Add, add, AddAssign, add_assign);
744impl_op!(Sub, sub, SubAssign, sub_assign);
745
746impl<S: Shape, N: Normalisation> Index<S> for SfsBase<S, N> {
747    type Output = f64;
748
749    #[inline]
750    fn index(&self, index: S) -> &Self::Output {
751        self.get(&index).unwrap()
752    }
753}
754
755impl<S: Shape> IndexMut<S> for SfsBase<S, Unnorm> {
756    #[inline]
757    fn index_mut(&mut self, index: S) -> &mut Self::Output {
758        self.get_mut(&index).unwrap()
759    }
760}
761
762impl<const D: usize, N: Normalisation> From<SfsBase<ConstShape<D>, N>> for SfsBase<DynShape, N> {
763    fn from(sfs: SfsBase<ConstShape<D>, N>) -> Self {
764        SfsBase {
765            values: sfs.values,
766            shape: sfs.shape.into(),
767            strides: sfs.strides.into(),
768            norm: PhantomData,
769        }
770    }
771}
772
773impl<const D: usize, N: Normalisation> TryFrom<SfsBase<DynShape, N>> for SfsBase<ConstShape<D>, N> {
774    type Error = SfsBase<DynShape, N>;
775
776    fn try_from(sfs: SfsBase<DynShape, N>) -> Result<Self, Self::Error> {
777        match (
778            <[usize; D]>::try_from(&sfs.shape[..]),
779            <[usize; D]>::try_from(&sfs.strides[..]),
780        ) {
781            (Ok(shape), Ok(strides)) => Ok(SfsBase {
782                values: sfs.values,
783                shape,
784                strides,
785                norm: PhantomData,
786            }),
787            (Err(_), Err(_)) => Err(sfs),
788            (Ok(_), Err(_)) | (Err(_), Ok(_)) => {
789                unreachable!("conversion of dyn shape and strides succeeds or fails together")
790            }
791        }
792    }
793}
794
795/// An error associated with SFS construction using invalid shape.
796#[derive(Clone, Copy, Debug)]
797pub struct ShapeError<S: Shape> {
798    n: usize,
799    shape: S,
800}
801
802impl<S: Shape> ShapeError<S> {
803    fn new(n: usize, shape: S) -> Self {
804        Self { n, shape }
805    }
806}
807
808impl<S: Shape> fmt::Display for ShapeError<S> {
809    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
810        let shape_fmt = self
811            .shape
812            .iter()
813            .map(|x| x.to_string())
814            .collect::<Vec<_>>()
815            .join("/");
816        let n = self.n;
817        let d = self.shape.as_ref().len();
818
819        write!(
820            f,
821            "cannot create {d}D SFS with shape {shape_fmt} from {n} elements"
822        )
823    }
824}
825
826impl<S: Shape> Error for ShapeError<S> {}
827
/// An error returned when values that do not sum to one are used as a normalised SFS.
#[derive(Clone, Copy, Debug)]
pub struct NormError {
    /// The sum of the offending values.
    sum: f64,
}

impl fmt::Display for NormError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { sum } = self;

        write!(
            f,
            "cannot create normalised SFS using values summing to {}",
            sum
        )
    }
}

impl Error for NormError {}
845
846fn compute_flat<S: Shape>(index: &S, shape: &S) -> Option<usize> {
847    assert_eq!(index.len(), shape.len());
848
849    for i in 1..index.len() {
850        if index.as_ref()[i] >= shape.as_ref()[i] {
851            return None;
852        }
853    }
854    Some(compute_flat_unchecked(index, shape))
855}
856
857fn compute_flat_unchecked<S: Shape>(index: &S, shape: &S) -> usize {
858    let mut flat = index.as_ref()[0];
859    for i in 1..index.len() {
860        flat *= shape.as_ref()[i];
861        flat += index.as_ref()[i];
862    }
863    flat
864}
865
/// Returns the sum of the per-dimension indices corresponding to the flat index `flat`.
///
/// `n` must be the total number of elements implied by `shape` (the product of its entries).
/// No validation is performed; in particular, a zero-length dimension leads to a division by
/// zero.
fn compute_index_sum_unchecked(mut flat: usize, mut n: usize, shape: &[usize]) -> usize {
    shape
        .iter()
        .map(|&dim| {
            // After this division, `n` is the stride of the current dimension.
            n /= dim;
            let index = flat / n;
            // What is left over belongs to the remaining (faster-varying) dimensions.
            flat %= n;
            index
        })
        .sum()
}
875
// Unit tests for indexing, the f2-statistic, element-wise arithmetic, and folding.
#[cfg(test)]
mod tests {
    use super::*;

    // In-bounds lookups return the value, the first out-of-bounds index returns `None`.
    #[test]
    fn test_index_1d() {
        let sfs = sfs1d![0., 1., 2., 3., 4., 5.];
        assert_eq!(sfs.get(&[0]), Some(&0.));
        assert_eq!(sfs.get(&[2]), Some(&2.));
        assert_eq!(sfs.get(&[5]), Some(&5.));
        assert_eq!(sfs.get(&[6]), None);
    }

    // Lookups are row-major, and going out of bounds in either dimension returns `None`.
    #[test]
    fn test_index_2d() {
        let sfs = sfs2d![[0., 1., 2.], [3., 4., 5.]];
        assert_eq!(sfs.get(&[0, 0]), Some(&0.));
        assert_eq!(sfs.get(&[1, 0]), Some(&3.));
        assert_eq!(sfs.get(&[1, 1]), Some(&4.));
        assert_eq!(sfs.get(&[1, 2]), Some(&5.));
        assert_eq!(sfs.get(&[2, 0]), None);
        assert_eq!(sfs.get(&[0, 3]), None);
    }

    // The expected value is 6.25 / 15, compared with a tolerance.
    #[test]
    fn test_f2() {
        #[rustfmt::skip]
        let sfs = sfs2d![
            [0., 1., 2.],
            [3., 4., 5.]
        ].normalise();
        assert!((sfs.f2() - 0.4166667).abs() < 1e-6);
    }

    // Covers `+` and `+=` with both owned and borrowed right-hand sides.
    #[test]
    fn test_sfs_addition() {
        let mut lhs = sfs1d![0., 1., 2.];
        let rhs = sfs1d![5., 6., 7.];
        let sum = sfs1d![5., 7., 9.];

        assert_eq!(lhs.clone() + rhs.clone(), sum);
        assert_eq!(lhs.clone() + &rhs, sum);

        lhs += rhs.clone();
        assert_eq!(lhs, sum);
        lhs += &rhs;
        assert_eq!(lhs, sum + rhs);
    }

    // Covers `-` and `-=` with both owned and borrowed right-hand sides.
    #[test]
    fn test_sfs_subtraction() {
        let mut lhs = sfs1d![5., 6., 7.];
        let rhs = sfs1d![0., 1., 2.];
        let sub = sfs1d![5., 5., 5.];

        assert_eq!(lhs.clone() - rhs.clone(), sub);
        assert_eq!(lhs.clone() - &rhs, sub);

        lhs -= rhs.clone();
        assert_eq!(lhs, sub);
        lhs -= &rhs;
        assert_eq!(lhs, sub - rhs);
    }

    // 1D with an even number of entries: no diagonal.
    #[test]
    fn test_fold_4() {
        let sfs = sfs1d![0., 1., 2., 3.];

        assert_eq!(sfs.fold(), sfs1d![3., 3., 0., 0.],);
    }

    // 1D with an odd number of entries: the middle entry is on the diagonal.
    #[test]
    fn test_fold_5() {
        let sfs = sfs1d![0., 1., 2., 3., 4.];

        assert_eq!(sfs.fold(), sfs1d![4., 4., 2., 0., 0.],);
    }

    // Square 2D input: the anti-diagonal is averaged.
    #[test]
    fn test_fold_3x3() {
        #[rustfmt::skip]
        let sfs = sfs2d![
            [0., 1., 2.],
            [3., 4., 5.],
            [6., 7., 8.],
        ];

        #[rustfmt::skip]
        let expected = sfs2d![
            [8., 8., 4.],
            [8., 4., 0.],
            [4., 0., 0.],
        ];

        assert_eq!(sfs.fold(), expected);
    }

    // Non-square 2D input that still has a diagonal.
    #[test]
    fn test_fold_2x4() {
        #[rustfmt::skip]
        let sfs = sfs2d![
            [0., 1., 2., 3.],
            [4., 5., 6., 7.],
        ];

        #[rustfmt::skip]
        let expected = sfs2d![
            [7., 7.,  3.5, 0.],
            [7., 3.5, 0.,  0.],
        ];

        assert_eq!(sfs.fold(), expected);
    }

    // Non-square 2D input without a diagonal.
    #[test]
    fn test_fold_3x4() {
        #[rustfmt::skip]
        let sfs = sfs2d![
            [0., 1.,  2.,  3.],
            [4., 5.,  6.,  7.],
            [8., 9., 10., 11.],
        ];

        #[rustfmt::skip]
        let expected = sfs2d![
            [11., 11., 11., 0.],
            [11., 11.,  0., 0.],
            [11.,  0.,  0., 0.],
        ];

        assert_eq!(sfs.fold(), expected);
    }

    // Wide non-square 2D input with a diagonal.
    #[test]
    fn test_fold_3x7() {
        #[rustfmt::skip]
        let sfs = sfs2d![
            [ 0.,  1.,  2.,  3.,  4.,  5.,  6.],
            [ 7.,  8.,  9., 10., 11., 12., 13.],
            [14., 15., 16., 17., 18., 19., 20.],
        ];

        #[rustfmt::skip]
        let expected = sfs2d![
            [20., 20., 20., 20., 10., 0., 0.],
            [20., 20., 20., 10.,  0., 0., 0.],
            [20., 20., 10.,  0.,  0., 0., 0.],
        ];

        assert_eq!(sfs.fold(), expected);
    }

    // 3D input without a diagonal.
    #[test]
    fn test_fold_2x2x2() {
        let sfs = USfs::from_iter_shape((0..8).map(|x| x as f64), [2, 2, 2]).unwrap();

        #[rustfmt::skip]
        let expected = USfs::from_vec_shape(
            vec![
                7., 7.,
                7., 0.,

                7., 0.,
                0., 0.,
            ],
            [2, 2, 2]
        ).unwrap();

        assert_eq!(sfs.fold(), expected);
    }

    // Non-cubic 3D input with a diagonal.
    #[test]
    fn test_fold_2x3x2() {
        let sfs = USfs::from_iter_shape((0..12).map(|x| x as f64), [2, 3, 2]).unwrap();

        #[rustfmt::skip]
        let expected = USfs::from_vec_shape(
            vec![
                11., 11.,
                11.,  5.5,
                5.5,  0.,

                11.,  5.5,
                 5.5, 0.,
                 0.,  0.,
            ],
            [2, 3, 2]
        ).unwrap();

        assert_eq!(sfs.fold(), expected);
    }

    // Cubic 3D input with a diagonal.
    #[test]
    fn test_fold_3x3x3() {
        let sfs = USfs::from_iter_shape((0..27).map(|x| x as f64), [3, 3, 3]).unwrap();

        #[rustfmt::skip]
        let expected = USfs::from_vec_shape(
            vec![
                26., 26., 26.,
                26., 26., 13.,
                26., 13.,  0.,

                26., 26., 13.,
                26., 13.,  0.,
                13.,  0.,  0.,

                26., 13.,  0.,
                13.,  0.,  0.,
                 0.,  0.,  0.,
            ],
            [3, 3, 3]
        ).unwrap();

        assert_eq!(sfs.fold(), expected);
    }
}