1use std::{
4 fmt,
5 marker::PhantomData,
6 ops::{AddAssign, Index, IndexMut, Range},
7};
8
9mod count;
10pub use count::Count;
11
12pub mod io;
13
14pub mod iter;
15use iter::FrequenciesIter;
16
17mod folded;
18pub use folded::Folded;
19
20pub(crate) mod project;
21use project::Projection;
22pub use project::ProjectionError;
23
24mod stat;
25pub use stat::StatisticError;
26
27use crate::array::{Array, Axis, Shape, ShapeError};
28
mod seal {
    //! Private home of the supertrait used to seal `State`.
    #![deny(missing_docs)]

    /// Private supertrait of `State`.
    ///
    /// Because this module is private, the trait can only be named (and hence
    /// implemented) from within the parent module.
    pub trait Sealed {}
}
33use seal::Sealed;
34
35pub trait State: Sealed {
39 #[doc(hidden)]
40 fn debug_name() -> &'static str;
41}
42
43#[derive(Copy, Clone, Debug)]
47pub struct Frequencies;
48impl Sealed for Frequencies {}
49impl State for Frequencies {
50 fn debug_name() -> &'static str {
51 "Sfs"
52 }
53}
54
55#[derive(Copy, Clone, Debug, Eq, PartialEq)]
59pub struct Counts;
60impl Sealed for Counts {}
61impl State for Counts {
62 fn debug_name() -> &'static str {
63 "Scs"
64 }
65}
66
67pub type Sfs = Spectrum<Frequencies>;
69
70pub type Scs = Spectrum<Counts>;
72
73#[derive(PartialEq)]
77pub struct Spectrum<S: State> {
78 array: Array<f64>,
79 state: PhantomData<S>,
80}
81
82impl<S: State> Spectrum<S> {
83 pub fn dimensions(&self) -> usize {
85 self.array.dimensions()
86 }
87
88 pub fn elements(&self) -> usize {
90 self.array.elements()
91 }
92
93 pub fn fold(&self) -> Folded<S> {
95 Folded::from_spectrum(self)
96 }
97
98 pub fn inner(&self) -> &Array<f64> {
100 &self.array
101 }
102
103 pub fn into_normalized(mut self) -> Sfs {
105 self.normalize();
106 self.into_state_unchecked()
107 }
108
109 fn into_state_unchecked<R: State>(self) -> Spectrum<R> {
110 Spectrum {
111 array: self.array,
112 state: PhantomData,
113 }
114 }
115
116 pub fn iter_frequencies(&self) -> FrequenciesIter<'_> {
122 FrequenciesIter::new(self)
123 }
124
125 pub fn king(&self) -> Result<f64, StatisticError> {
133 stat::King::from_spectrum(self)
134 .map(|x| x.0)
135 .map_err(Into::into)
136 }
137
138 pub fn marginalize(&self, axes: &[Axis]) -> Result<Self, MarginalizationError> {
144 if let Some(duplicate) = axes.iter().enumerate().find_map(|(i, axis)| {
145 axes.get(i + 1..)
146 .and_then(|slice| slice.contains(axis).then_some(axis))
147 }) {
148 return Err(MarginalizationError::DuplicateAxis { axis: duplicate.0 });
149 };
150
151 if let Some(out_of_bounds) = axes.iter().find(|axis| axis.0 >= self.dimensions()) {
152 return Err(MarginalizationError::AxisOutOfBounds {
153 axis: out_of_bounds.0,
154 dimensions: self.dimensions(),
155 });
156 };
157
158 if axes.len() >= self.dimensions() {
159 return Err(MarginalizationError::TooManyAxes {
160 axes: axes.len(),
161 dimensions: self.dimensions(),
162 });
163 }
164
165 let is_sorted = axes.windows(2).all(|w| w[0] <= w[1]);
166 if is_sorted {
167 Ok(self.marginalize_unchecked(axes))
168 } else {
169 let mut axes = axes.to_vec();
170 axes.sort();
171 Ok(self.marginalize_unchecked(&axes))
172 }
173 }
174
175 fn marginalize_axis(&self, axis: Axis) -> Self {
176 Scs::from(self.array.sum(axis)).into_state_unchecked()
177 }
178
179 fn marginalize_unchecked(&self, axes: &[Axis]) -> Self {
180 let mut spectrum = self.clone();
181
182 axes.iter()
185 .enumerate()
186 .map(|(removed, original)| Axis(original.0 - removed))
187 .for_each(|axis| {
188 spectrum = spectrum.marginalize_axis(axis);
189 });
190
191 spectrum
192 }
193
194 pub fn normalize(&mut self) {
199 let sum = self.sum();
200 self.array.iter_mut().for_each(|x| *x /= sum);
201 }
202
203 pub fn pi(&self) -> Result<f64, StatisticError> {
209 stat::Pi::from_spectrum(self)
210 .map(|x| x.0)
211 .map_err(Into::into)
212 }
213
214 pub fn pi_xy(&self) -> Result<f64, StatisticError> {
223 stat::PiXY::from_spectrum(self)
224 .map(|x| x.0)
225 .map_err(Into::into)
226 }
227
228 pub fn project<T>(&self, project_to: T) -> Result<Self, ProjectionError>
238 where
239 T: Into<Shape>,
240 {
241 let project_to = project_to.into();
242 let mut projection = Projection::from_shapes(self.shape().clone(), project_to.clone())?;
243 let mut new = Scs::from_zeros(project_to);
244
245 for (&weight, from) in self.array.iter().zip(self.array.iter_indices().map(Count)) {
246 projection
247 .project_unchecked(&from)
248 .into_weighted(weight)
249 .add_unchecked(&mut new);
250 }
251
252 Ok(new.into_state_unchecked())
253 }
254
255 pub fn r0(&self) -> Result<f64, StatisticError> {
263 stat::R0::from_spectrum(self)
264 .map(|x| x.0)
265 .map_err(Into::into)
266 }
267
268 pub fn r1(&self) -> Result<f64, StatisticError> {
276 stat::R1::from_spectrum(self)
277 .map(|x| x.0)
278 .map_err(Into::into)
279 }
280
281 pub fn shape(&self) -> &Shape {
283 self.array.shape()
284 }
285
286 pub fn sum(&self) -> f64 {
288 self.array.iter().sum::<f64>()
289 }
290
291 pub fn theta_watterson(&self) -> Result<f64, StatisticError> {
297 stat::Theta::<stat::theta::Watterson>::from_spectrum(self)
298 .map(|x| x.0)
299 .map_err(Into::into)
300 }
301}
302
303impl Scs {
304 pub fn d_fu_li(&self) -> Result<f64, StatisticError> {
312 stat::D::<stat::d::FuLi>::from_scs(self)
313 .map(|x| x.0)
314 .map_err(Into::into)
315 }
316
317 pub fn d_tajima(&self) -> Result<f64, StatisticError> {
325 stat::D::<stat::d::Tajima>::from_scs(self)
326 .map(|x| x.0)
327 .map_err(Into::into)
328 }
329
330 pub fn from_range<S>(range: Range<usize>, shape: S) -> Result<Self, ShapeError>
338 where
339 S: Into<Shape>,
340 {
341 Array::from_iter(range.map(|v| v as f64), shape).map(Self::from)
342 }
343
344 pub fn from_vec<T>(vec: T) -> Self
346 where
347 T: Into<Vec<f64>>,
348 {
349 let vec = vec.into();
350 let shape = vec.len();
351 Self::new(vec, shape).unwrap()
352 }
353
354 pub fn from_zeros<S>(shape: S) -> Self
356 where
357 S: Into<Shape>,
358 {
359 Self::from(Array::from_zeros(shape))
360 }
361
362 pub fn inner_mut(&mut self) -> &mut Array<f64> {
364 &mut self.array
365 }
366
367 pub fn new<D, S>(data: D, shape: S) -> Result<Self, ShapeError>
373 where
374 D: Into<Vec<f64>>,
375 S: Into<Shape>,
376 {
377 Array::new(data, shape).map(Self::from)
378 }
379
380 pub fn segregating_sites(&self) -> f64 {
382 let n = self.elements();
383
384 self.array.iter().take(n - 1).skip(1).sum()
385 }
386}
387
388impl Sfs {
389 pub fn f2(&self) -> Result<f64, StatisticError> {
397 stat::F2::from_sfs(self).map(|x| x.0).map_err(Into::into)
398 }
399
400 pub fn f3(&self) -> Result<f64, StatisticError> {
412 stat::F3::from_sfs(self).map(|x| x.0).map_err(Into::into)
413 }
414
415 pub fn f4(&self) -> Result<f64, StatisticError> {
427 stat::F4::from_sfs(self).map(|x| x.0).map_err(Into::into)
428 }
429
430 pub fn fst(&self) -> Result<f64, StatisticError> {
438 stat::Fst::from_sfs(self).map(|x| x.0).map_err(Into::into)
439 }
440}
441
442impl<S: State> Clone for Spectrum<S> {
443 fn clone(&self) -> Self {
444 Self {
445 array: self.array.clone(),
446 state: PhantomData,
447 }
448 }
449}
450
451impl<S: State> fmt::Debug for Spectrum<S> {
452 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
453 f.debug_struct(S::debug_name())
454 .field("array", &self.array)
455 .finish()
456 }
457}
458
459impl AddAssign<&Count> for Scs {
460 fn add_assign(&mut self, count: &Count) {
461 self[count] += 1.0;
462 }
463}
464
465impl From<Array<f64>> for Scs {
466 fn from(array: Array<f64>) -> Self {
467 Self {
468 array,
469 state: PhantomData,
470 }
471 }
472}
473
474impl<I, S: State> Index<I> for Spectrum<S>
475where
476 I: AsRef<[usize]>,
477{
478 type Output = f64;
479
480 fn index(&self, index: I) -> &Self::Output {
481 self.array.index(index)
482 }
483}
484
485impl<I, S: State> IndexMut<I> for Spectrum<S>
486where
487 I: AsRef<[usize]>,
488{
489 fn index_mut(&mut self, index: I) -> &mut Self::Output {
490 self.array.index_mut(index)
491 }
492}
493
/// Error returned by [`Spectrum::marginalize`] when the requested axes are invalid.
///
/// The error holds only plain integers, so it is cheap to copy and can be
/// stored in hashed collections.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MarginalizationError {
    /// The same axis was given more than once.
    DuplicateAxis {
        /// The repeated axis.
        axis: usize,
    },
    /// An axis was not less than the number of dimensions of the spectrum.
    AxisOutOfBounds {
        /// The offending axis.
        axis: usize,
        /// The number of dimensions of the spectrum.
        dimensions: usize,
    },
    /// The number of axes was not less than the number of dimensions, so no
    /// dimension would remain after marginalization.
    TooManyAxes {
        /// The number of axes given.
        axes: usize,
        /// The number of dimensions of the spectrum.
        dimensions: usize,
    },
}
517
518impl fmt::Display for MarginalizationError {
519 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
520 match self {
521 MarginalizationError::DuplicateAxis { axis } => {
522 write!(f, "cannot marginalize with duplicate axis {axis}")
523 }
524 MarginalizationError::AxisOutOfBounds { axis, dimensions } => write!(
525 f,
526 "cannot marginalize axis {axis} in spectrum with {dimensions} dimensions"
527 ),
528 MarginalizationError::TooManyAxes { axes, dimensions } => write!(
529 f,
530 "cannot marginalize a total of {axes} axes in spectrum with {dimensions} dimensions"
531 ),
532 }
533 }
534}
535
536impl std::error::Error for MarginalizationError {}
537
#[cfg(test)]
mod tests {
    use super::*;

    use crate::approx::ApproxEq;

    // Approximate comparison of spectra, delegating to the `ApproxEq`
    // implementation of the underlying array.
    impl<S: State> ApproxEq for Spectrum<S> {
        const DEFAULT_EPSILON: Self::Epsilon = <f64 as ApproxEq>::DEFAULT_EPSILON;

        type Epsilon = <f64 as ApproxEq>::Epsilon;

        fn approx_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
            self.array.approx_eq(&other.array, epsilon)
        }
    }

    #[test]
    fn test_marginalize_axis_2d() {
        let scs = Scs::from_range(0..9, [3, 3]).unwrap();

        assert_eq!(
            scs.marginalize_axis(Axis(0)),
            Scs::new([9., 12., 15.], 3).unwrap()
        );

        assert_eq!(
            scs.marginalize_axis(Axis(1)),
            Scs::new([3., 12., 21.], 3).unwrap()
        );
    }

    #[test]
    fn test_marginalize_axis_3d() {
        let scs = Scs::from_range(0..27, [3, 3, 3]).unwrap();

        assert_eq!(
            scs.marginalize_axis(Axis(0)),
            Scs::new([27., 30., 33., 36., 39., 42., 45., 48., 51.], [3, 3]).unwrap()
        );

        assert_eq!(
            scs.marginalize_axis(Axis(1)),
            Scs::new([9., 12., 15., 36., 39., 42., 63., 66., 69.], [3, 3]).unwrap()
        );

        assert_eq!(
            scs.marginalize_axis(Axis(2)),
            Scs::new([3., 12., 21., 30., 39., 48., 57., 66., 75.], [3, 3]).unwrap()
        );
    }

    #[test]
    fn test_marginalize_3d() {
        let scs = Scs::from_range(0..27, [3, 3, 3]).unwrap();

        // The result must not depend on the order in which axes are given.
        let expected = Scs::new([90., 117., 144.], [3]).unwrap();
        assert_eq!(scs.marginalize(&[Axis(0), Axis(2)]).unwrap(), expected);
        assert_eq!(scs.marginalize(&[Axis(2), Axis(0)]).unwrap(), expected);
    }

    #[test]
    fn test_marginalize_too_many_axes() {
        let scs = Scs::from_range(0..9, [3, 3]).unwrap();

        assert_eq!(
            scs.marginalize(&[Axis(0), Axis(1)]),
            Err(MarginalizationError::TooManyAxes {
                axes: 2,
                dimensions: 2
            }),
        );
    }

    #[test]
    fn test_marginalize_duplicate_axis() {
        let scs = Scs::from_range(0..27, [3, 3, 3]).unwrap();

        assert_eq!(
            scs.marginalize(&[Axis(1), Axis(1)]),
            Err(MarginalizationError::DuplicateAxis { axis: 1 }),
        );
    }

    #[test]
    fn test_marginalize_axis_out_of_bounds() {
        let scs = Scs::from_range(0..9, [3, 3]).unwrap();

        assert_eq!(
            scs.marginalize(&[Axis(2)]),
            Err(MarginalizationError::AxisOutOfBounds {
                axis: 2,
                dimensions: 2
            }),
        );
    }

    #[test]
    fn test_project_7_to_3() {
        let scs = Scs::from_range(0..7, 7).unwrap();
        let projected = scs.project(3).unwrap();
        let expected = Scs::new([2.333333, 7.0, 11.666667], 3).unwrap();
        assert_approx_eq!(projected, expected, epsilon = 1e-6);
    }

    #[test]
    fn test_project_7_to_7_is_identity() {
        let scs = Scs::from_range(0..7, 7).unwrap();
        let projected = scs.project(7).unwrap();
        assert_eq!(scs, projected);
    }

    #[test]
    fn test_project_7_to_8_is_error() {
        let scs = Scs::from_range(0..7, 7).unwrap();
        let result = scs.project(8);

        assert!(matches!(
            result,
            Err(ProjectionError::InvalidProjection { .. })
        ));
    }

    #[test]
    fn test_project_7_to_0_is_error() {
        let scs = Scs::from_range(0..7, 7).unwrap();
        let result = scs.project(0);

        assert!(matches!(result, Err(ProjectionError::Zero)));
    }

    #[test]
    fn test_project_3x3_to_2x2() {
        let scs = Scs::from_range(0..9, [3, 3]).unwrap();
        let projected = scs.project([2, 2]).unwrap();
        let expected = Scs::new([3.0, 6.0, 12.0, 15.0], [2, 2]).unwrap();
        assert_approx_eq!(projected, expected, epsilon = 1e-6);
    }
}