sfs_core/
spectrum.rs

1//! Frequency and count spectra.
2
3use std::{
4    fmt,
5    marker::PhantomData,
6    ops::{AddAssign, Index, IndexMut, Range},
7};
8
9mod count;
10pub use count::Count;
11
12pub mod io;
13
14pub mod iter;
15use iter::FrequenciesIter;
16
17mod folded;
18pub use folded::Folded;
19
20pub(crate) mod project;
21use project::Projection;
22pub use project::ProjectionError;
23
24mod stat;
25pub use stat::StatisticError;
26
27use crate::array::{Array, Axis, Shape, ShapeError};
28
// Private module backing the sealed-trait pattern: `Sealed` is `pub` so that it can be used as a
// supertrait of the public `State` trait, but since the module itself is private, code outside
// this crate cannot name the trait and therefore cannot implement `State`.
mod seal {
    #![deny(missing_docs)]
    pub trait Sealed {}
}
33use seal::Sealed;
34
/// A type that can be used as marker for the state of a [`Spectrum`].
///
/// The states defined in this module are [`Frequencies`] and [`Counts`].
///
/// This trait is sealed and cannot be implemented outside this crate.
pub trait State: Sealed {
    /// Returns the name used as the struct name when debug-formatting a [`Spectrum`] in this
    /// state.
    #[doc(hidden)]
    fn debug_name() -> &'static str;
}
42
43/// A marker struct for a [`Spectrum`] of frequencies.
44///
45/// See also [`Sfs`].
46#[derive(Copy, Clone, Debug)]
47pub struct Frequencies;
48impl Sealed for Frequencies {}
49impl State for Frequencies {
50    fn debug_name() -> &'static str {
51        "Sfs"
52    }
53}
54
/// A marker struct for a [`Spectrum`] of counts.
///
/// This is a zero-sized type that only exists at the type level.
///
/// See also [`Scs`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct Counts;
impl Sealed for Counts {}
impl State for Counts {
    /// Returns the name shown for an [`Scs`] in debug output.
    fn debug_name() -> &'static str {
        "Scs"
    }
}
66
/// A site frequency spectrum.
///
/// This is a [`Spectrum`] in the [`Frequencies`] state.
pub type Sfs = Spectrum<Frequencies>;

/// A site count spectrum.
///
/// This is a [`Spectrum`] in the [`Counts`] state.
pub type Scs = Spectrum<Counts>;
72
/// A site spectrum.
///
/// The spectrum may either be over frequencies ([`Sfs`]) or counts ([`Scs`]).
///
/// The state `S` is tracked only at the type level via [`PhantomData`]; the values are stored in
/// the same kind of underlying array regardless of state.
///
/// Note that [`PartialEq`] is derived, so it is only implemented for a spectrum whose state
/// marker `S` itself implements [`PartialEq`].
#[derive(PartialEq)]
pub struct Spectrum<S: State> {
    // The values of the spectrum.
    array: Array<f64>,
    // Zero-sized marker tying the spectrum to its state.
    state: PhantomData<S>,
}
81
82impl<S: State> Spectrum<S> {
    /// Returns the number of dimensions of the spectrum.
    ///
    /// This is the number of dimensions of the underlying array.
    pub fn dimensions(&self) -> usize {
        self.array.dimensions()
    }
87
    /// Returns the number of elements in the spectrum.
    ///
    /// This is the number of elements in the underlying array.
    pub fn elements(&self) -> usize {
        self.array.elements()
    }
92
    /// Returns a folded spectrum.
    ///
    /// The folded spectrum is created from a shared reference to `self`, so `self` is left
    /// unchanged.
    pub fn fold(&self) -> Folded<S> {
        Folded::from_spectrum(self)
    }
97
    /// Returns the underlying array.
    ///
    /// The array is returned by shared reference and cannot be modified through it.
    pub fn inner(&self) -> &Array<f64> {
        &self.array
    }
102
    /// Returns a normalized frequency spectrum, consuming `self`.
    ///
    /// The spectrum is first normalized in-place using [`Spectrum::normalize`], and is then
    /// converted to an [`Sfs`] at the type-level by moving the underlying array.
    pub fn into_normalized(mut self) -> Sfs {
        self.normalize();
        // After normalization the values are frequencies, so changing the state is justified.
        self.into_state_unchecked()
    }
108
109    fn into_state_unchecked<R: State>(self) -> Spectrum<R> {
110        Spectrum {
111            array: self.array,
112            state: PhantomData,
113        }
114    }
115
    /// Returns an iterator over the allele frequencies of the elements in the spectrum in row-major
    /// order.
    ///
    /// Note that this is not an iterator over frequencies in the sense of a frequency spectrum, but
    /// in the sense of allele frequencies corresponding to indices in a spectrum.
    ///
    /// The returned iterator borrows from `self`.
    pub fn iter_frequencies(&self) -> FrequenciesIter<'_> {
        FrequenciesIter::new(self)
    }
124
125    /// Returns the King statistic.
126    ///
127    /// See Manichaikul (2010) and Waples (2019) for details.
128    ///
129    /// # Errors
130    ///
131    /// If the spectrum is not a 3x3 2-dimensional spectrum.
132    pub fn king(&self) -> Result<f64, StatisticError> {
133        stat::King::from_spectrum(self)
134            .map(|x| x.0)
135            .map_err(Into::into)
136    }
137
138    /// Returns a spectrum with the provided axes marginalized out.
139    ///
140    /// # Errors
141    ///
142    /// If the provided axes contain duplicates, or if any of them are out of bounds.
143    pub fn marginalize(&self, axes: &[Axis]) -> Result<Self, MarginalizationError> {
144        if let Some(duplicate) = axes.iter().enumerate().find_map(|(i, axis)| {
145            axes.get(i + 1..)
146                .and_then(|slice| slice.contains(axis).then_some(axis))
147        }) {
148            return Err(MarginalizationError::DuplicateAxis { axis: duplicate.0 });
149        };
150
151        if let Some(out_of_bounds) = axes.iter().find(|axis| axis.0 >= self.dimensions()) {
152            return Err(MarginalizationError::AxisOutOfBounds {
153                axis: out_of_bounds.0,
154                dimensions: self.dimensions(),
155            });
156        };
157
158        if axes.len() >= self.dimensions() {
159            return Err(MarginalizationError::TooManyAxes {
160                axes: axes.len(),
161                dimensions: self.dimensions(),
162            });
163        }
164
165        let is_sorted = axes.windows(2).all(|w| w[0] <= w[1]);
166        if is_sorted {
167            Ok(self.marginalize_unchecked(axes))
168        } else {
169            let mut axes = axes.to_vec();
170            axes.sort();
171            Ok(self.marginalize_unchecked(&axes))
172        }
173    }
174
175    fn marginalize_axis(&self, axis: Axis) -> Self {
176        Scs::from(self.array.sum(axis)).into_state_unchecked()
177    }
178
179    fn marginalize_unchecked(&self, axes: &[Axis]) -> Self {
180        let mut spectrum = self.clone();
181
182        // As we marginalize out axes one by one, the axes shift down,
183        // so we subtract the number already removed and rely on axes having been sorted
184        axes.iter()
185            .enumerate()
186            .map(|(removed, original)| Axis(original.0 - removed))
187            .for_each(|axis| {
188                spectrum = spectrum.marginalize_axis(axis);
189            });
190
191        spectrum
192    }
193
194    /// Normalizes the spectrum to frequencies in-place.
195    ///
196    /// See also [`Spectrum::into_normalized`] to normalize and convert to an [`Sfs`] at the
197    /// type-level.
198    pub fn normalize(&mut self) {
199        let sum = self.sum();
200        self.array.iter_mut().for_each(|x| *x /= sum);
201    }
202
203    /// Returns the average number of pairwise differences, also known as π.
204    ///
205    /// # Errors
206    ///
207    /// If the spectrum is not a 1-dimensional spectrum.
208    pub fn pi(&self) -> Result<f64, StatisticError> {
209        stat::Pi::from_spectrum(self)
210            .map(|x| x.0)
211            .map_err(Into::into)
212    }
213
    /// Returns the average number of pairwise differences between two populations, also known as
    /// πₓᵧ or Dₓᵧ.
    ///
    /// See Nei and Li (1987).
    ///
    /// # Errors
    ///
    /// If the spectrum is not a 2-dimensional spectrum.
    pub fn pi_xy(&self) -> Result<f64, StatisticError> {
        stat::PiXY::from_spectrum(self)
            .map(|x| x.0)
            .map_err(Into::into)
    }
227
228    /// Returns a spectrum projected down to a shape.
229    ///
230    /// The projection is based on hypergeometric down-sampling. See Marth (2004) and Gutenkunst
231    /// (2009) for details. Note that projecting a spectrum after creation may cause problems;
232    /// prefer projecting site-wise during creation where possible.
233    ///
234    /// # Errors
235    ///
236    /// Errors if the projected shape is not valid for the provided spectrum.
237    pub fn project<T>(&self, project_to: T) -> Result<Self, ProjectionError>
238    where
239        T: Into<Shape>,
240    {
241        let project_to = project_to.into();
242        let mut projection = Projection::from_shapes(self.shape().clone(), project_to.clone())?;
243        let mut new = Scs::from_zeros(project_to);
244
245        for (&weight, from) in self.array.iter().zip(self.array.iter_indices().map(Count)) {
246            projection
247                .project_unchecked(&from)
248                .into_weighted(weight)
249                .add_unchecked(&mut new);
250        }
251
252        Ok(new.into_state_unchecked())
253    }
254
255    /// Returns the R0 statistic.
256    ///
257    /// See Waples (2019) for details.
258    ///
259    /// # Errors
260    ///
261    /// If the spectrum is not a 3x3 2-dimensional spectrum.
262    pub fn r0(&self) -> Result<f64, StatisticError> {
263        stat::R0::from_spectrum(self)
264            .map(|x| x.0)
265            .map_err(Into::into)
266    }
267
    /// Returns the R1 statistic.
    ///
    /// See Waples (2019) for details.
    ///
    /// # Errors
    ///
    /// If the spectrum is not a 3x3 2-dimensional spectrum.
    pub fn r1(&self) -> Result<f64, StatisticError> {
        stat::R1::from_spectrum(self)
            .map(|x| x.0)
            .map_err(Into::into)
    }
280
    /// Returns the shape of the spectrum.
    ///
    /// This is the shape of the underlying array.
    pub fn shape(&self) -> &Shape {
        self.array.shape()
    }
285
286    /// Returns the sum of elements in the spectrum.
287    pub fn sum(&self) -> f64 {
288        self.array.iter().sum::<f64>()
289    }
290
291    /// Returns Watterson's estimator of the mutation-scaled effective population size θ.
292    ///
293    /// # Errors
294    ///
295    /// If the spectrum is not a 1-dimensional spectrum.
296    pub fn theta_watterson(&self) -> Result<f64, StatisticError> {
297        stat::Theta::<stat::theta::Watterson>::from_spectrum(self)
298            .map(|x| x.0)
299            .map_err(Into::into)
300    }
301}
302
303impl Scs {
304    /// Returns Fu and Li's D difference statistic.
305    ///
306    /// See Fu and Li (1993).
307    ///
308    /// # Errors
309    ///
310    /// If the spectrum is not a 1-dimensional spectrum.
311    pub fn d_fu_li(&self) -> Result<f64, StatisticError> {
312        stat::D::<stat::d::FuLi>::from_scs(self)
313            .map(|x| x.0)
314            .map_err(Into::into)
315    }
316
317    /// Returns Tajima's D difference statistic.
318    ///
319    /// See Tajima (1989).
320    ///
321    /// # Errors
322    ///
323    /// If the spectrum is not a 1-dimensional spectrum.
324    pub fn d_tajima(&self) -> Result<f64, StatisticError> {
325        stat::D::<stat::d::Tajima>::from_scs(self)
326            .map(|x| x.0)
327            .map_err(Into::into)
328    }
329
330    /// Creates a new spectrum from a range and a shape.
331    ///
332    /// This is mainly intended for testing and illustration.
333    ///
334    /// # Errors
335    ///
336    /// If the number of items in the range does not match the provided shape.
337    pub fn from_range<S>(range: Range<usize>, shape: S) -> Result<Self, ShapeError>
338    where
339        S: Into<Shape>,
340    {
341        Array::from_iter(range.map(|v| v as f64), shape).map(Self::from)
342    }
343
344    /// Creates a new one-dimensional spectrum from a vector.
345    pub fn from_vec<T>(vec: T) -> Self
346    where
347        T: Into<Vec<f64>>,
348    {
349        let vec = vec.into();
350        let shape = vec.len();
351        Self::new(vec, shape).unwrap()
352    }
353
354    /// Creates a new spectrum filled with zeros to a shape.
355    pub fn from_zeros<S>(shape: S) -> Self
356    where
357        S: Into<Shape>,
358    {
359        Self::from(Array::from_zeros(shape))
360    }
361
    /// Returns a mutable reference to the underlying array.
    ///
    /// Note that this is available only for an [`Scs`]; see [`Spectrum::inner`] for shared
    /// access in any state.
    pub fn inner_mut(&mut self) -> &mut Array<f64> {
        &mut self.array
    }
366
367    /// Creates a new spectrum from data in row-major order and a shape.
368    ///
369    /// # Errors
370    ///
371    /// If the number of items in the data does not match the provided shape.
372    pub fn new<D, S>(data: D, shape: S) -> Result<Self, ShapeError>
373    where
374        D: Into<Vec<f64>>,
375        S: Into<Shape>,
376    {
377        Array::new(data, shape).map(Self::from)
378    }
379
380    /// Returns the number of sites segregating in any population in the spectrum.
381    pub fn segregating_sites(&self) -> f64 {
382        let n = self.elements();
383
384        self.array.iter().take(n - 1).skip(1).sum()
385    }
386}
387
388impl Sfs {
389    /// Returns the f₂ statistic.
390    ///
391    /// See Reich (2009) and Peter (2016) for details.
392    ///
393    /// # Errors
394    ///
395    /// If the spectrum is not a 2-dimensional spectrum.
396    pub fn f2(&self) -> Result<f64, StatisticError> {
397        stat::F2::from_sfs(self).map(|x| x.0).map_err(Into::into)
398    }
399
400    /// Returns the f₃(A; B, C)-statistic, where A, B, C is in the order of the populations in the
401    /// spectrum.
402    ///
403    /// Note that f₃ may also be calculated as a linear combination of f₂, which is often going to
404    /// be more efficient and flexible.
405    ///
406    /// See Reich (2009) and Peter (2016) for details.
407    ///
408    /// # Errors
409    ///
410    /// If the spectrum is not a 3-dimensional spectrum.
411    pub fn f3(&self) -> Result<f64, StatisticError> {
412        stat::F3::from_sfs(self).map(|x| x.0).map_err(Into::into)
413    }
414
    /// Returns the f₄(A, B; C, D)-statistic, where A, B, C, D are in the order of the populations
    /// in the spectrum.
    ///
    /// Note that f₄ may also be calculated as a linear combination of f₂, which is often going to
    /// be more efficient and flexible.
    ///
    /// See Reich (2009) and Peter (2016) for details.
    ///
    /// # Errors
    ///
    /// If the spectrum is not a 4-dimensional spectrum.
    pub fn f4(&self) -> Result<f64, StatisticError> {
        stat::F4::from_sfs(self).map(|x| x.0).map_err(Into::into)
    }
429
430    /// Returns Hudson's estimator of Fst.
431    ///
432    /// See Bhatia (2013) for details. (This uses a "ratio of estimates" as recommended there.)
433    ///
434    /// # Errors
435    ///
436    /// If the spectrum is not a 2-dimensional spectrum.
437    pub fn fst(&self) -> Result<f64, StatisticError> {
438        stat::Fst::from_sfs(self).map(|x| x.0).map_err(Into::into)
439    }
440}
441
442impl<S: State> Clone for Spectrum<S> {
443    fn clone(&self) -> Self {
444        Self {
445            array: self.array.clone(),
446            state: PhantomData,
447        }
448    }
449}
450
451impl<S: State> fmt::Debug for Spectrum<S> {
452    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
453        f.debug_struct(S::debug_name())
454            .field("array", &self.array)
455            .finish()
456    }
457}
458
459impl AddAssign<&Count> for Scs {
460    fn add_assign(&mut self, count: &Count) {
461        self[count] += 1.0;
462    }
463}
464
465impl From<Array<f64>> for Scs {
466    fn from(array: Array<f64>) -> Self {
467        Self {
468            array,
469            state: PhantomData,
470        }
471    }
472}
473
// Indexing is available in any state and for any index type that can be viewed as a slice of
// `usize`; the lookup itself is forwarded to the underlying array.
impl<I, S: State> Index<I> for Spectrum<S>
where
    I: AsRef<[usize]>,
{
    type Output = f64;

    fn index(&self, index: I) -> &Self::Output {
        self.array.index(index)
    }
}
484
// Mutable indexing mirrors the shared indexing implementation; the lookup itself is forwarded to
// the underlying array.
impl<I, S: State> IndexMut<I> for Spectrum<S>
where
    I: AsRef<[usize]>,
{
    fn index_mut(&mut self, index: I) -> &mut Self::Output {
        self.array.index_mut(index)
    }
}
493
/// An error associated with marginalizing a spectrum.
#[derive(Debug, Eq, PartialEq)]
pub enum MarginalizationError {
    /// The same axis was provided more than once.
    DuplicateAxis {
        /// The axis that was duplicated.
        axis: usize,
    },
    /// A provided axis does not exist in the spectrum.
    AxisOutOfBounds {
        /// The axis that is out of bounds.
        axis: usize,
        /// The number of dimensions in the spectrum.
        dimensions: usize,
    },
    /// Too many axes were provided.
    TooManyAxes {
        /// The number of provided axes.
        axes: usize,
        /// The number of dimensions in the spectrum.
        dimensions: usize,
    },
}

impl fmt::Display for MarginalizationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // All fields are `usize`, so the variants can be matched by value.
        match *self {
            Self::DuplicateAxis { axis } => {
                write!(f, "cannot marginalize with duplicate axis {axis}")
            }
            Self::AxisOutOfBounds { axis, dimensions } => {
                write!(
                    f,
                    "cannot marginalize axis {axis} in spectrum with {dimensions} dimensions"
                )
            }
            Self::TooManyAxes { axes, dimensions } => {
                write!(
                    f,
                    "cannot marginalize a total of {axes} axes in spectrum with {dimensions} dimensions"
                )
            }
        }
    }
}

impl std::error::Error for MarginalizationError {}
537
#[cfg(test)]
mod tests {
    use super::*;

    use crate::approx::ApproxEq;

    // Approximate equality of spectra is defined by approximate equality of the underlying
    // arrays, using the same epsilon type and default as for `f64`.
    impl<S: State> ApproxEq for Spectrum<S> {
        const DEFAULT_EPSILON: Self::Epsilon = <f64 as ApproxEq>::DEFAULT_EPSILON;

        type Epsilon = <f64 as ApproxEq>::Epsilon;

        fn approx_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
            self.array.approx_eq(&other.array, epsilon)
        }
    }

    #[test]
    fn test_marginalize_axis_2d() {
        let scs = Scs::from_range(0..9, [3, 3]).unwrap();

        assert_eq!(
            scs.marginalize_axis(Axis(0)),
            Scs::new([9., 12., 15.], 3).unwrap()
        );

        assert_eq!(
            scs.marginalize_axis(Axis(1)),
            Scs::new([3., 12., 21.], 3).unwrap()
        );
    }

    #[test]
    fn test_marginalize_axis_3d() {
        let scs = Scs::from_range(0..27, [3, 3, 3]).unwrap();

        assert_eq!(
            scs.marginalize_axis(Axis(0)),
            Scs::new([27., 30., 33., 36., 39., 42., 45., 48., 51.], [3, 3]).unwrap()
        );

        assert_eq!(
            scs.marginalize_axis(Axis(1)),
            Scs::new([9., 12., 15., 36., 39., 42., 63., 66., 69.], [3, 3]).unwrap()
        );

        assert_eq!(
            scs.marginalize_axis(Axis(2)),
            Scs::new([3., 12., 21., 30., 39., 48., 57., 66., 75.], [3, 3]).unwrap()
        );
    }

    #[test]
    fn test_marginalize_3d() {
        let scs = Scs::from_range(0..27, [3, 3, 3]).unwrap();

        // The order in which the axes are provided must not affect the result.
        let expected = Scs::new([90., 117., 144.], [3]).unwrap();
        assert_eq!(scs.marginalize(&[Axis(0), Axis(2)]).unwrap(), expected);
        assert_eq!(scs.marginalize(&[Axis(2), Axis(0)]).unwrap(), expected);
    }

    #[test]
    fn test_marginalize_too_many_axes() {
        let scs = Scs::from_range(0..9, [3, 3]).unwrap();

        assert_eq!(
            scs.marginalize(&[Axis(0), Axis(1)]),
            Err(MarginalizationError::TooManyAxes {
                axes: 2,
                dimensions: 2
            }),
        );
    }

    #[test]
    fn test_marginalize_duplicate_axis() {
        let scs = Scs::from_range(0..27, [3, 3, 3]).unwrap();

        assert_eq!(
            scs.marginalize(&[Axis(1), Axis(1)]),
            Err(MarginalizationError::DuplicateAxis { axis: 1 }),
        );
    }

    #[test]
    fn test_marginalize_axis_out_of_bounds() {
        let scs = Scs::from_range(0..9, [3, 3]).unwrap();

        assert_eq!(
            scs.marginalize(&[Axis(2)]),
            Err(MarginalizationError::AxisOutOfBounds {
                axis: 2,
                dimensions: 2
            }),
        );
    }

    #[test]
    fn test_project_7_to_3() {
        let scs = Scs::from_range(0..7, 7).unwrap();
        let projected = scs.project(3).unwrap();
        let expected = Scs::new([2.333333, 7.0, 11.666667], 3).unwrap();
        assert_approx_eq!(projected, expected, epsilon = 1e-6);
    }

    #[test]
    fn test_project_7_to_7_is_identity() {
        let scs = Scs::from_range(0..7, 7).unwrap();
        let projected = scs.project(7).unwrap();
        assert_eq!(scs, projected);
    }

    #[test]
    fn test_project_7_to_8_is_error() {
        let scs = Scs::from_range(0..7, 7).unwrap();
        let result = scs.project(8);

        assert!(matches!(
            result,
            Err(ProjectionError::InvalidProjection { .. })
        ));
    }

    #[test]
    fn test_project_7_to_0_is_error() {
        let scs = Scs::from_range(0..7, 7).unwrap();
        let result = scs.project(0);

        assert!(matches!(result, Err(ProjectionError::Zero)));
    }

    #[test]
    fn test_project_3x3_to_2x2() {
        let scs = Scs::from_range(0..9, [3, 3]).unwrap();
        let projected = scs.project([2, 2]).unwrap();
        let expected = Scs::new([3.0, 6.0, 12.0, 15.0], [2, 2]).unwrap();
        assert_approx_eq!(projected, expected, epsilon = 1e-6);
    }
}