winsfs_core/sfs.rs
1//! Multi-dimensional site frequency spectra ("SFS").
2//!
3//! The central type is the [`SfsBase`] struct, which represents an SFS with a dimensionality
4//! that may or may not be known at compile time, and which may or may not be normalised to
5//! probability scale. Type aliases [`Sfs`], [`USfs`], [`DynSfs`], and [`DynUSfs`] are exposed
6//! for convenience.
7
8use std::{
9 cmp::Ordering,
10 error::Error,
11 fmt::{self, Write as _},
12 marker::PhantomData,
13 ops::{Add, AddAssign, Index, IndexMut, Sub, SubAssign},
14 slice,
15};
16
17use crate::ArrayExt;
18
19pub mod generics;
20use generics::{ConstShape, DynShape, Norm, Normalisation, Shape, Unnorm};
21
22pub mod io;
23
24pub mod iter;
25use iter::Indices;
26
27mod em;
28
/// Largest absolute deviation from one that the sum of the values may have for an SFS to
/// still be accepted as normalised (see `SfsBase::into_normalised`).
const NORMALISATION_TOLERANCE: f64 = 10. * f64::EPSILON;
30
/// Creates an unnormalised 1D SFS.
///
/// Two forms are accepted: a single element repeated a number of times, or an explicit,
/// comma-separated list of elements. The macro exists mostly to keep doc-tests readable,
/// but it may be convenient in other code as well.
///
/// # Examples
///
/// Fill an SFS by repeating one element:
///
/// ```
/// use winsfs_core::sfs1d;
/// let sfs = sfs1d![0.1; 10];
/// assert!(sfs.iter().all(|&x| x == 0.1));
/// ```
///
/// Build an SFS from explicitly listed elements:
///
/// ```
/// use winsfs_core::sfs1d;
/// let sfs = sfs1d![0.1, 0.2, 0.3];
/// assert_eq!(sfs[[0]], 0.1);
/// assert_eq!(sfs[[1]], 0.2);
/// assert_eq!(sfs[[2]], 0.3);
/// ```
#[macro_export]
macro_rules! sfs1d {
    // `sfs1d![value; len]`: one value repeated `len` times.
    ($value:expr; $len:expr) => {
        $crate::sfs::USfs::from_elem($value, [$len])
    };
    // `sfs1d![a, b, c]`: explicit values, optionally with a trailing comma.
    ($($value:expr),+ $(,)?) => {
        $crate::sfs::USfs::from_vec(vec![$($value),+])
    };
}
63
/// Creates an unnormalised 2D SFS.
///
/// Rows are written as bracketed, comma-separated lists of literals. The macro exists
/// mostly to keep doc-tests readable, but it may be convenient in other code as well.
///
/// # Examples
///
/// ```
/// use winsfs_core::sfs2d;
/// let sfs = sfs2d![
///     [0.1, 0.2, 0.3],
///     [0.4, 0.5, 0.6],
///     [0.7, 0.8, 0.9],
/// ];
/// assert_eq!(sfs[[0, 0]], 0.1);
/// assert_eq!(sfs[[1, 0]], 0.4);
/// assert_eq!(sfs[[2, 0]], 0.7);
/// ```
#[macro_export]
macro_rules! sfs2d {
    ($([$($value:literal),+ $(,)?]),+ $(,)?) => {{
        let (cols, values) = $crate::matrix!($([$($value),+]),+);
        // The shape is read off the first tuple element returned by `matrix!`: its length
        // gives the first dimension, and its first entry gives the second dimension.
        $crate::sfs::SfsBase::from_vec_shape(values, [cols.len(), cols[0]]).unwrap()
    }};
}
89
/// A multi-dimensional site frequency spectrum ("SFS").
///
/// Elements are stored in row-major order: the last index varies the fastest.
///
/// The number of dimensions of the SFS may either be known at compile-time or run-time,
/// and this is governed by the [`Shape`] trait. Moreover, the SFS may or may not be normalised
/// to probability scale, and this is controlled by the [`Normalisation`] trait.
/// See also the [`Sfs`], [`USfs`], [`DynSfs`], and [`DynUSfs`] type aliases.
#[derive(Clone, Debug, PartialEq)]
// TODO: Replace normalisation with const enum once these are permitted in const generics,
// see github.com/rust-lang/rust/issues/95174
pub struct SfsBase<S: Shape, N: Normalisation> {
    /// The values of the SFS, stored flat in row-major order.
    values: Vec<f64>,
    /// The shape of the SFS, i.e. the number of entries along each dimension.
    pub(crate) shape: S,
    /// The strides of the SFS, as derived from the shape by [`Shape::strides`].
    pub(crate) strides: S,
    /// Zero-sized marker recording the normalisation state on the type level.
    norm: PhantomData<N>,
}
107
/// A normalised SFS with shape known at compile-time.
///
/// The const parameter `D` is the number of dimensions of the SFS.
pub type Sfs<const D: usize> = SfsBase<ConstShape<D>, Norm>;

/// An unnormalised SFS with shape known at compile-time.
///
/// The const parameter `D` is the number of dimensions of the SFS.
pub type USfs<const D: usize> = SfsBase<ConstShape<D>, Unnorm>;

/// A normalised SFS with shape known at run-time.
pub type DynSfs = SfsBase<DynShape, Norm>;

/// An unnormalised SFS with shape known at run-time.
pub type DynUSfs = SfsBase<DynShape, Unnorm>;
119
120impl<S: Shape, N: Normalisation> SfsBase<S, N> {
121 /// Returns the values of the SFS as a flat, row-major slice.
122 ///
123 /// # Examples
124 ///
125 /// ```
126 /// use winsfs_core::sfs2d;
127 /// let sfs = sfs2d![
128 /// [0., 1., 2.],
129 /// [3., 4., 5.],
130 /// ];
131 /// assert_eq!(sfs.as_slice(), [0., 1., 2., 3., 4., 5.]);
132 /// ```
133 #[inline]
134 pub fn as_slice(&self) -> &[f64] {
135 &self.values
136 }
137
138 /// Returns a folded version of the SFS.
139 ///
140 /// Folding is useful when the spectrum has not been properly polarised, so that there is
141 /// no meaningful distinction between having 0 and 2N (in the diploid case) variants at a site.
142 /// The folding operation collapses these indistinguishable bins by adding the value from the
143 /// lower part of the spectrum onto the upper, and setting the lower value to zero.
144 ///
145 /// Note that we adopt the convention that on the "diagonal" of the SFS, where there is less of
146 /// a convention on what is the correct way of folding, the arithmetic mean of the candidates is
147 /// used. The examples below illustrate this.
148 ///
149 /// # Examples
150 ///
151 /// Folding in 1D:
152 ///
153 /// ```
154 /// use winsfs_core::sfs1d;
155 /// let sfs = sfs1d![5., 2., 3., 10., 1.];
156 /// assert_eq!(sfs.fold(), sfs1d![6., 12., 3., 0., 0.]);
157 /// ```
158 ///
159 /// Folding in 2D (square input):
160 ///
161 /// ```
162 /// use winsfs_core::sfs2d;
163 /// let sfs = sfs2d![
164 /// [4., 2., 10.],
165 /// [0., 3., 4.],
166 /// [7., 2., 1.],
167 /// ];
168 /// let expected = sfs2d![
169 /// [5., 4., 8.5],
170 /// [4., 3., 0.],
171 /// [8.5, 0., 0.],
172 /// ];
173 /// assert_eq!(sfs.fold(), expected);
174 /// ```
175 ///
176 /// Folding in 2D (non-square input):
177 ///
178 /// ```
179 /// use winsfs_core::sfs2d;
180 /// let sfs = sfs2d![
181 /// [4., 2., 10.],
182 /// [0., 3., 4.],
183 /// ];
184 /// let expected = sfs2d![
185 /// [8., 5., 0.],
186 /// [10., 0., 0.],
187 /// ];
188 /// assert_eq!(sfs.fold(), expected);
189 /// ```
190 pub fn fold(&self) -> Self {
191 let n = self.values.len();
192 let total_count = self.shape.iter().sum::<usize>() - self.shape.len();
193
194 // In general, this point divides the folding line. Since we are folding onto the "upper"
195 // part of the array, we want to fold anything "below" it onto something "above" it.
196 let mid_count = total_count / 2;
197
198 // The spectrum may or may not have a "diagonal", i.e. a hyperplane that falls exactly on
199 // the midpoint. If such a diagonal exists, we need to handle it as a special case when
200 // folding below.
201 //
202 // For example, in 1D a spectrum with five elements has a "diagonal", marked X:
203 // [-, -, X, -, -]
204 // Whereas on with four elements would not.
205 //
206 // In two dimensions, e.g. three-by-three elements has a diagonal:
207 // [-, -, X]
208 // [-, X, -]
209 // [X, -, -]
210 // whereas two-by-three would not. On the other hand, two-by-four has a diagonal:
211 // [-, -, X, -]
212 // [-, X, -, -]
213 //
214 // Note that even-ploidy data should always have a diagonal, whereas odd-ploidy data
215 // may or may not.
216 let has_diagonal = total_count % 2 == 0;
217
218 // Note that we cannot use the algorithm below in-place, since the reverse iterator
219 // may reach elements that have already been folded, which causes bugs. Hence we fold
220 // into a zero-initialised copy.
221 let mut folded = Self::new_unchecked(vec![0.0; n], self.shape.clone());
222
223 // We iterate over indices rather than values since we have to mutate on the array
224 // while looking at it from both directions.
225 (0..n).zip((0..n).rev()).for_each(|(i, rev_i)| {
226 let count = compute_index_sum_unchecked(i, n, self.shape.as_ref());
227
228 match (count.cmp(&mid_count), has_diagonal) {
229 (Ordering::Less, _) | (Ordering::Equal, false) => {
230 // We are in the upper part of the spectrum that should be folded onto.
231 folded.values[i] = self.values[i] + self.values[rev_i];
232 }
233 (Ordering::Equal, true) => {
234 // We are on a diagonal, which must be handled as a special case:
235 // there are apparently different opinions on what the most correct
236 // thing to do is. This adopts the same strategy as e.g. in dadi.
237 folded.values[i] = 0.5 * self.values[i] + 0.5 * self.values[rev_i];
238 }
239 (Ordering::Greater, _) => (),
240 }
241 });
242
243 folded
244 }
245
246 /// Returns a string containing a flat, row-major represention of the SFS.
247 ///
248 /// # Examples
249 ///
250 /// ```
251 /// use winsfs_core::sfs1d;
252 /// let sfs = sfs1d![0.0, 0.1, 0.2];
253 /// assert_eq!(sfs.format_flat(" ", 1), "0.0 0.1 0.2");
254 /// ```
255 ///
256 /// ```
257 /// use winsfs_core::sfs2d;
258 /// let sfs = sfs2d![[0.01, 0.12], [0.23, 0.34]];
259 /// assert_eq!(sfs.format_flat(",", 2), "0.01,0.12,0.23,0.34");
260 /// ```
261 pub fn format_flat(&self, sep: &str, precision: usize) -> String {
262 if let Some(first) = self.values.first() {
263 let cap = self.values.len() * (precision + 3);
264 let mut init = String::with_capacity(cap);
265 write!(init, "{first:.precision$}").unwrap();
266 // init.push_str(&format!("{:.precision$}", first));
267
268 self.iter().skip(1).fold(init, |mut s, x| {
269 s.push_str(sep);
270 write!(s, "{x:.precision$}").unwrap();
271 s
272 })
273 } else {
274 String::new()
275 }
276 }
277
278 /// Returns a value at an index in the SFS.
279 ///
280 /// If the index is out of bounds, `None` is returned.
281 ///
282 /// # Examples
283 ///
284 /// ```
285 /// use winsfs_core::sfs1d;
286 /// let sfs = sfs1d![0.0, 0.1, 0.2];
287 /// assert_eq!(sfs.get(&[0]), Some(&0.0));
288 /// assert_eq!(sfs.get(&[1]), Some(&0.1));
289 /// assert_eq!(sfs.get(&[2]), Some(&0.2));
290 /// assert_eq!(sfs.get(&[3]), None);
291 /// ```
292 ///
293 /// ```
294 /// use winsfs_core::sfs2d;
295 /// let sfs = sfs2d![[0.0, 0.1, 0.2], [0.3, 0.4, 0.5], [0.6, 0.7, 0.8]];
296 /// assert_eq!(sfs.get(&[0, 0]), Some(&0.0));
297 /// assert_eq!(sfs.get(&[1, 2]), Some(&0.5));
298 /// assert_eq!(sfs.get(&[3, 0]), None);
299 /// ```
300 #[inline]
301 pub fn get(&self, index: &S) -> Option<&f64> {
302 self.values.get(compute_flat(index, &self.shape)?)
303 }
304
305 /// Returns a normalised SFS, consuming `self`.
306 ///
307 /// This works purely on the type level, and does not modify the actual values in the SFS.
308 /// If the SFS is not already normalised, an error is returned. To modify the SFS to become
309 /// normalised, see [`Sfs::normalise`].
310 ///
311 /// # Examples
312 ///
313 /// An unnormalised SFS with values summing to one can be turned into a normalised SFS:
314 ///
315 /// ```
316 /// use winsfs_core::{sfs1d, sfs::{Sfs, USfs}};
317 /// let sfs: USfs<1> = sfs1d![0.2; 5];
318 /// let sfs: Sfs<1> = sfs.into_normalised().unwrap();
319 /// ```
320 ///
321 /// Otherwise, an unnormalised SFS cannot be normalised SFS using this method:
322 ///
323 /// ```
324 /// use winsfs_core::{sfs1d, sfs::USfs};
325 /// let sfs: USfs<1> = sfs1d![2.; 5];
326 /// assert!(sfs.into_normalised().is_err());
327 /// ```
328 ///
329 /// Use [`Sfs::normalise`] instead.
330 #[inline]
331 pub fn into_normalised(self) -> Result<SfsBase<S, Norm>, NormError> {
332 let sum = self.sum();
333
334 if (sum - 1.).abs() <= NORMALISATION_TOLERANCE {
335 Ok(self.into_normalised_unchecked())
336 } else {
337 Err(NormError { sum })
338 }
339 }
340
341 #[inline]
342 fn into_normalised_unchecked(self) -> SfsBase<S, Norm> {
343 SfsBase {
344 values: self.values,
345 shape: self.shape,
346 strides: self.strides,
347 norm: PhantomData,
348 }
349 }
350
351 /// Returns an unnormalised SFS, consuming `self`.
352 ///
353 /// This works purely on the type level, and does not modify the actual values in the SFS.
354 ///
355 /// # Examples
356 ///
357 /// ```
358 /// use winsfs_core::sfs::{Sfs, USfs};
359 /// let sfs: Sfs<1> = Sfs::uniform([7]);
360 /// let sfs: USfs<1> = sfs.into_unnormalised();
361 /// ```
362 #[inline]
363 pub fn into_unnormalised(self) -> SfsBase<S, Unnorm> {
364 SfsBase {
365 values: self.values,
366 shape: self.shape,
367 strides: self.strides,
368 norm: PhantomData,
369 }
370 }
371
372 /// Returns an iterator over the elements in the SFS in row-major order.
373 ///
374 /// # Examples
375 ///
376 /// ```
377 /// use winsfs_core::sfs2d;
378 /// let sfs = sfs2d![
379 /// [0., 1., 2.],
380 /// [3., 4., 5.],
381 /// [6., 7., 8.],
382 /// ];
383 /// let expected = (0..9).map(|x| x as f64);
384 /// assert!(sfs.iter().zip(expected).all(|(&x, y)| x == y));
385 /// ```
386 #[inline]
387 pub fn iter(&self) -> slice::Iter<'_, f64> {
388 self.values.iter()
389 }
390
391 /// Creates a new SFS.
392 #[inline]
393 fn new_unchecked(values: Vec<f64>, shape: S) -> Self {
394 let strides = shape.strides();
395
396 Self {
397 values,
398 shape,
399 strides,
400 norm: PhantomData,
401 }
402 }
403
404 /// Returns an unnormalised SFS scaled by some constant, consuming `self`.
405 ///
406 /// # Examples
407 ///
408 /// ```
409 /// use winsfs_core::sfs1d;
410 /// assert_eq!(
411 /// sfs1d![0., 1., 2., 3., 4.].scale(10.),
412 /// sfs1d![0., 10., 20., 30., 40.],
413 /// );
414 /// ```
415 #[inline]
416 #[must_use = "returns scaled SFS, doesn't modify in-place"]
417 pub fn scale(mut self, scale: f64) -> SfsBase<S, Unnorm> {
418 self.values.iter_mut().for_each(|x| *x *= scale);
419
420 self.into_unnormalised()
421 }
422
423 /// Returns the SFS shape.
424 ///
425 /// # Examples
426 ///
427 /// ```
428 /// use winsfs_core::sfs2d;
429 /// let sfs = sfs2d![
430 /// [0., 1., 2.],
431 /// [3., 4., 5.],
432 /// ];
433 /// assert_eq!(sfs.shape(), &[2, 3]);
434 /// ```
435 pub fn shape(&self) -> &S {
436 &self.shape
437 }
438
439 /// Returns the sum of values in the SFS.
440 #[inline]
441 fn sum(&self) -> f64 {
442 self.iter().sum()
443 }
444}
445
446impl<const D: usize, N: Normalisation> SfsBase<ConstShape<D>, N> {
447 /// Returns an iterator over the sample frequencies of the SFS in row-major order.
448 ///
449 /// Note that this is *not* the contents of SFS, but the frequencies corresponding
450 /// to the indices. See [`Sfs::iter`] for an iterator over the SFS values themselves.
451 ///
452 /// # Examples
453 ///
454 /// ```
455 /// use winsfs_core::sfs::Sfs;
456 /// let sfs = Sfs::uniform([2, 3]);
457 /// let mut iter = sfs.frequencies();
458 /// assert_eq!(iter.next(), Some([0., 0.]));
459 /// assert_eq!(iter.next(), Some([0., 0.5]));
460 /// assert_eq!(iter.next(), Some([0., 1.]));
461 /// assert_eq!(iter.next(), Some([1., 0.]));
462 /// assert_eq!(iter.next(), Some([1., 0.5]));
463 /// assert_eq!(iter.next(), Some([1., 1.]));
464 /// assert!(iter.next().is_none());
465 /// ```
466 pub fn frequencies(&self) -> impl Iterator<Item = [f64; D]> {
467 let n_arr = self.shape.map(|n| n - 1);
468 self.indices()
469 .map(move |idx_arr| idx_arr.array_zip(n_arr).map(|(i, n)| i as f64 / n as f64))
470 }
471
472 /// Returns an iterator over the indices in the SFS in row-major order.
473 ///
474 /// # Examples
475 ///
476 /// ```
477 /// use winsfs_core::sfs::Sfs;
478 /// let sfs = Sfs::uniform([2, 3]);
479 /// let mut iter = sfs.indices();
480 /// assert_eq!(iter.next(), Some([0, 0]));
481 /// assert_eq!(iter.next(), Some([0, 1]));
482 /// assert_eq!(iter.next(), Some([0, 2]));
483 /// assert_eq!(iter.next(), Some([1, 0]));
484 /// assert_eq!(iter.next(), Some([1, 1]));
485 /// assert_eq!(iter.next(), Some([1, 2]));
486 /// assert!(iter.next().is_none());
487 /// ```
488 pub fn indices(&self) -> Indices<ConstShape<D>> {
489 Indices::from_shape(self.shape)
490 }
491}
492
493impl<S: Shape> SfsBase<S, Norm> {
494 /// Creates a new, normalised, and uniform SFS.
495 ///
496 /// # Examples
497 ///
498 /// ```
499 /// use winsfs_core::sfs::Sfs;
500 /// let sfs = Sfs::uniform([2, 5]);
501 /// assert!(sfs.iter().all(|&x| x == 0.1));
502 /// ```
503 pub fn uniform(shape: S) -> SfsBase<S, Norm> {
504 let n: usize = shape.iter().product();
505
506 let elem = 1.0 / n as f64;
507
508 SfsBase::new_unchecked(vec![elem; n], shape)
509 }
510}
511
512impl<S: Shape> SfsBase<S, Unnorm> {
513 /// Returns the a mutable reference values of the SFS as a flat, row-major slice.
514 ///
515 /// # Examples
516 ///
517 /// ```
518 /// use winsfs_core::sfs2d;
519 /// let mut sfs = sfs2d![
520 /// [0., 1., 2.],
521 /// [3., 4., 5.],
522 /// ];
523 /// assert_eq!(sfs.as_slice(), [0., 1., 2., 3., 4., 5.]);
524 /// sfs.as_mut_slice()[0] = 100.;
525 /// assert_eq!(sfs.as_slice(), [100., 1., 2., 3., 4., 5.]);
526 /// ```
527 #[inline]
528 pub fn as_mut_slice(&mut self) -> &mut [f64] {
529 &mut self.values
530 }
531
532 /// Creates a new, unnormalised SFS by repeating a single value.
533 ///
534 /// See also [`Sfs::uniform`] to create a normalised SFS with uniform values.
535 ///
536 /// # Examples
537 ///
538 /// ```
539 /// use winsfs_core::sfs::USfs;
540 /// let sfs = USfs::from_elem(0.1, [7, 5]);
541 /// assert_eq!(sfs.shape(), &[7, 5]);
542 /// assert!(sfs.iter().all(|&x| x == 0.1));
543 /// ```
544 pub fn from_elem(elem: f64, shape: S) -> Self {
545 let n = shape.iter().product();
546
547 Self::new_unchecked(vec![elem; n], shape)
548 }
549
550 /// Creates a new, unnormalised SFS from an iterator.
551 ///
552 /// # Examples
553 ///
554 /// ```
555 /// use winsfs_core::sfs::USfs;
556 /// let iter = (0..9).map(|x| x as f64);
557 /// let sfs = USfs::from_iter_shape(iter, [3, 3]).expect("shape didn't fit iterator!");
558 /// assert_eq!(sfs[[1, 2]], 5.0);
559 /// ```
560 pub fn from_iter_shape<I>(iter: I, shape: S) -> Result<Self, ShapeError<S>>
561 where
562 I: IntoIterator<Item = f64>,
563 {
564 Self::from_vec_shape(iter.into_iter().collect(), shape)
565 }
566
567 /// Creates a new, unnormalised SFS from a vector.
568 ///
569 /// # Examples
570 ///
571 /// ```
572 /// use winsfs_core::sfs::USfs;
573 /// let vec: Vec<f64> = (0..9).map(|x| x as f64).collect();
574 /// let sfs = USfs::from_vec_shape(vec, [3, 3]).expect("shape didn't fit vector!");
575 /// assert_eq!(sfs[[2, 0]], 6.0);
576 /// ```
577 pub fn from_vec_shape(vec: Vec<f64>, shape: S) -> Result<Self, ShapeError<S>> {
578 let n: usize = shape.iter().product();
579
580 match vec.len() == n {
581 true => Ok(Self::new_unchecked(vec, shape)),
582 false => Err(ShapeError::new(n, shape)),
583 }
584 }
585
586 /// Returns a mutable reference to a value at an index in the SFS.
587 ///
588 /// If the index is out of bounds, `None` is returned.
589 ///
590 /// # Examples
591 ///
592 /// ```
593 /// use winsfs_core::sfs1d;
594 /// let mut sfs = sfs1d![0.0, 0.1, 0.2];
595 /// assert_eq!(sfs[[0]], 0.0);
596 /// if let Some(v) = sfs.get_mut(&[0]) {
597 /// *v = 0.5;
598 /// }
599 /// assert_eq!(sfs[[0]], 0.5);
600 /// ```
601 ///
602 /// ```
603 /// use winsfs_core::sfs2d;
604 /// let mut sfs = sfs2d![[0.0, 0.1, 0.2], [0.3, 0.4, 0.5], [0.6, 0.7, 0.8]];
605 /// assert_eq!(sfs[[0, 0]], 0.0);
606 /// if let Some(v) = sfs.get_mut(&[0, 0]) {
607 /// *v = 0.5;
608 /// }
609 /// assert_eq!(sfs[[0, 0]], 0.5);
610 /// ```
611 #[inline]
612 pub fn get_mut(&mut self, index: &S) -> Option<&mut f64> {
613 self.values.get_mut(compute_flat(index, &self.shape)?)
614 }
615
616 /// Returns an iterator over mutable references to the elements in the SFS in row-major order.
617 #[inline]
618 pub fn iter_mut(&mut self) -> slice::IterMut<'_, f64> {
619 self.values.iter_mut()
620 }
621
622 /// Returns a normalised SFS, consuming `self`.
623 ///
624 /// The values in the SFS are modified to sum to one.
625 ///
626 /// # Examples
627 ///
628 /// ```
629 /// use winsfs_core::{sfs1d, sfs::{Sfs, USfs}};
630 /// let sfs: USfs<1> = sfs1d![0., 1., 2., 3., 4.];
631 /// let sfs: Sfs<1> = sfs.normalise();
632 /// assert_eq!(sfs[[1]], 0.1);
633 /// ```
634 #[inline]
635 #[must_use = "returns normalised SFS, doesn't modify in-place"]
636 pub fn normalise(mut self) -> SfsBase<S, Norm> {
637 let sum = self.sum();
638
639 self.iter_mut().for_each(|x| *x /= sum);
640
641 self.into_normalised_unchecked()
642 }
643
644 /// Creates a new, unnnormalised SFS with all entries set to zero.
645 ///
646 /// # Examples
647 ///
648 /// ```
649 /// use winsfs_core::sfs::USfs;
650 /// let sfs = USfs::zeros([2, 5]);
651 /// assert!(sfs.iter().all(|&x| x == 0.0));
652 /// ```
653 pub fn zeros(shape: S) -> Self {
654 Self::from_elem(0.0, shape)
655 }
656}
657
658impl SfsBase<ConstShape<1>, Unnorm> {
659 /// Creates a new SFS from a vector.
660 ///
661 /// # Examples
662 ///
663 /// ```
664 /// use winsfs_core::sfs::USfs;
665 /// let sfs = USfs::from_vec(vec![0., 1., 2.]);
666 /// assert_eq!(sfs.shape(), &[3]);
667 /// assert_eq!(sfs[[1]], 1.);
668 /// ```
669 pub fn from_vec(values: Vec<f64>) -> Self {
670 let shape = [values.len()];
671
672 Self::new_unchecked(values, shape)
673 }
674}
675
676impl SfsBase<ConstShape<2>, Norm> {
677 /// Returns the f2-statistic.
678 ///
679 /// # Examples
680 ///
681 /// ```
682 /// use winsfs_core::sfs2d;
683 /// let sfs = sfs2d![
684 /// [1., 0., 0.],
685 /// [0., 1., 0.],
686 /// [0., 0., 1.],
687 /// ].normalise();
688 /// assert_eq!(sfs.f2(), 0.);
689 /// ```
690 pub fn f2(&self) -> f64 {
691 self.iter()
692 .zip(self.frequencies())
693 .map(|(v, [f_i, f_j])| v * (f_i - f_j).powi(2))
694 .sum()
695 }
696}
697
// Implements an element-wise binary operator and the corresponding compound assignment
// operator for SFS types.
//
// The macro takes the operator trait and its method (e.g. `Add`, `add`), followed by the
// assignment trait and its method (e.g. `AddAssign`, `add_assign`).
//
// In-place assignment is only implemented for unnormalised left-hand sides, and the output of
// the binary operator is always unnormalised, regardless of the normalisation of the operands.
macro_rules! impl_op {
    ($trait:ident, $method:ident, $assign_trait:ident, $assign_method:ident) => {
        // `unnorm op= &sfs`: the element-wise work horse that the other impls delegate to.
        impl<S: Shape, N: Normalisation> $assign_trait<&SfsBase<S, N>> for SfsBase<S, Unnorm> {
            #[inline]
            fn $assign_method(&mut self, rhs: &SfsBase<S, N>) {
                // Panics if the operands have different shapes.
                assert_eq!(self.shape, rhs.shape);

                self.iter_mut()
                    .zip(rhs.iter())
                    .for_each(|(x, rhs)| x.$assign_method(rhs));
            }
        }

        // `unnorm op= sfs`: as above, but taking the right-hand side by value.
        impl<S: Shape, N: Normalisation> $assign_trait<SfsBase<S, N>> for SfsBase<S, Unnorm> {
            #[inline]
            fn $assign_method(&mut self, rhs: SfsBase<S, N>) {
                self.$assign_method(&rhs);
            }
        }

        // `sfs op sfs`: the left-hand side is converted to an unnormalised SFS and reused
        // as the output.
        impl<S: Shape, N: Normalisation, M: Normalisation> $trait<SfsBase<S, M>> for SfsBase<S, N> {
            type Output = SfsBase<S, Unnorm>;

            #[inline]
            fn $method(self, rhs: SfsBase<S, M>) -> Self::Output {
                let mut sfs = self.into_unnormalised();
                sfs.$assign_method(&rhs);
                sfs
            }
        }

        // `sfs op &sfs`: as above, but borrowing the right-hand side.
        impl<S: Shape, N: Normalisation, M: Normalisation> $trait<&SfsBase<S, M>>
            for SfsBase<S, N>
        {
            type Output = SfsBase<S, Unnorm>;

            #[inline]
            fn $method(self, rhs: &SfsBase<S, M>) -> Self::Output {
                let mut sfs = self.into_unnormalised();
                sfs.$assign_method(rhs);
                sfs
            }
        }
    };
}
// Element-wise addition and subtraction.
impl_op!(Add, add, AddAssign, add_assign);
impl_op!(Sub, sub, SubAssign, sub_assign);
745
746impl<S: Shape, N: Normalisation> Index<S> for SfsBase<S, N> {
747 type Output = f64;
748
749 #[inline]
750 fn index(&self, index: S) -> &Self::Output {
751 self.get(&index).unwrap()
752 }
753}
754
755impl<S: Shape> IndexMut<S> for SfsBase<S, Unnorm> {
756 #[inline]
757 fn index_mut(&mut self, index: S) -> &mut Self::Output {
758 self.get_mut(&index).unwrap()
759 }
760}
761
762impl<const D: usize, N: Normalisation> From<SfsBase<ConstShape<D>, N>> for SfsBase<DynShape, N> {
763 fn from(sfs: SfsBase<ConstShape<D>, N>) -> Self {
764 SfsBase {
765 values: sfs.values,
766 shape: sfs.shape.into(),
767 strides: sfs.strides.into(),
768 norm: PhantomData,
769 }
770 }
771}
772
773impl<const D: usize, N: Normalisation> TryFrom<SfsBase<DynShape, N>> for SfsBase<ConstShape<D>, N> {
774 type Error = SfsBase<DynShape, N>;
775
776 fn try_from(sfs: SfsBase<DynShape, N>) -> Result<Self, Self::Error> {
777 match (
778 <[usize; D]>::try_from(&sfs.shape[..]),
779 <[usize; D]>::try_from(&sfs.strides[..]),
780 ) {
781 (Ok(shape), Ok(strides)) => Ok(SfsBase {
782 values: sfs.values,
783 shape,
784 strides,
785 norm: PhantomData,
786 }),
787 (Err(_), Err(_)) => Err(sfs),
788 (Ok(_), Err(_)) | (Err(_), Ok(_)) => {
789 unreachable!("conversion of dyn shape and strides succeeds or fails together")
790 }
791 }
792 }
793}
794
/// An error associated with SFS construction using invalid shape.
#[derive(Clone, Copy, Debug)]
pub struct ShapeError<S: Shape> {
    /// The number of elements reported in the error message.
    n: usize,
    /// The shape that the SFS could not be created with.
    shape: S,
}
801
802impl<S: Shape> ShapeError<S> {
803 fn new(n: usize, shape: S) -> Self {
804 Self { n, shape }
805 }
806}
807
808impl<S: Shape> fmt::Display for ShapeError<S> {
809 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
810 let shape_fmt = self
811 .shape
812 .iter()
813 .map(|x| x.to_string())
814 .collect::<Vec<_>>()
815 .join("/");
816 let n = self.n;
817 let d = self.shape.as_ref().len();
818
819 write!(
820 f,
821 "cannot create {d}D SFS with shape {shape_fmt} from {n} elements"
822 )
823 }
824}
825
826impl<S: Shape> Error for ShapeError<S> {}
827
/// An error associated with normalised SFS construction using unnormalised input.
#[derive(Clone, Copy, Debug)]
pub struct NormError {
    /// The sum of the values that were rejected as not being normalised.
    sum: f64,
}
833
834impl fmt::Display for NormError {
835 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
836 write!(
837 f,
838 "cannot create normalised SFS using values summing to {}",
839 self.sum
840 )
841 }
842}
843
844impl Error for NormError {}
845
846fn compute_flat<S: Shape>(index: &S, shape: &S) -> Option<usize> {
847 assert_eq!(index.len(), shape.len());
848
849 for i in 1..index.len() {
850 if index.as_ref()[i] >= shape.as_ref()[i] {
851 return None;
852 }
853 }
854 Some(compute_flat_unchecked(index, shape))
855}
856
857fn compute_flat_unchecked<S: Shape>(index: &S, shape: &S) -> usize {
858 let mut flat = index.as_ref()[0];
859 for i in 1..index.len() {
860 flat *= shape.as_ref()[i];
861 flat += index.as_ref()[i];
862 }
863 flat
864}
865
/// Returns the sum of the multi-dimensional index that corresponds to the flat, row-major
/// position `flat` in an array with `n` elements and the provided `shape`.
///
/// No checks are made: the caller must ensure that `n` is the product of the entries in
/// `shape`, and that `flat` is smaller than `n`.
fn compute_index_sum_unchecked(mut flat: usize, mut n: usize, shape: &[usize]) -> usize {
    shape
        .iter()
        .map(|&len| {
            // After the division, `n` is the number of elements spanned by a single step
            // along the current dimension.
            n /= len;
            let index = flat / n;
            flat %= n;
            index
        })
        .sum()
}
875
#[cfg(test)]
mod tests {
    use super::*;

    /// Checked indexing into a 1D SFS, including one position past the end.
    #[test]
    fn test_index_1d() {
        let sfs = sfs1d![0., 1., 2., 3., 4., 5.];
        assert_eq!(sfs.get(&[0]), Some(&0.));
        assert_eq!(sfs.get(&[2]), Some(&2.));
        assert_eq!(sfs.get(&[5]), Some(&5.));
        assert_eq!(sfs.get(&[6]), None);
    }

    /// Checked indexing into a 2D SFS, including out-of-bounds in either dimension.
    #[test]
    fn test_index_2d() {
        let sfs = sfs2d![[0., 1., 2.], [3., 4., 5.]];
        assert_eq!(sfs.get(&[0, 0]), Some(&0.));
        assert_eq!(sfs.get(&[1, 0]), Some(&3.));
        assert_eq!(sfs.get(&[1, 1]), Some(&4.));
        assert_eq!(sfs.get(&[1, 2]), Some(&5.));
        assert_eq!(sfs.get(&[2, 0]), None);
        assert_eq!(sfs.get(&[0, 3]), None);
    }

    /// The f2-statistic of a small 2D SFS, compared to a reference value up to a tolerance.
    #[test]
    fn test_f2() {
        #[rustfmt::skip]
        let sfs = sfs2d![
            [0., 1., 2.],
            [3., 4., 5.]
        ].normalise();
        assert!((sfs.f2() - 0.4166667).abs() < 1e-6);
    }

    /// Addition by value and by reference, both as binary operator and in-place.
    #[test]
    fn test_sfs_addition() {
        let mut lhs = sfs1d![0., 1., 2.];
        let rhs = sfs1d![5., 6., 7.];
        let sum = sfs1d![5., 7., 9.];

        assert_eq!(lhs.clone() + rhs.clone(), sum);
        assert_eq!(lhs.clone() + &rhs, sum);

        lhs += rhs.clone();
        assert_eq!(lhs, sum);
        lhs += &rhs;
        assert_eq!(lhs, sum + rhs);
    }

    /// Subtraction by value and by reference, both as binary operator and in-place.
    #[test]
    fn test_sfs_subtraction() {
        let mut lhs = sfs1d![5., 6., 7.];
        let rhs = sfs1d![0., 1., 2.];
        let sub = sfs1d![5., 5., 5.];

        assert_eq!(lhs.clone() - rhs.clone(), sub);
        assert_eq!(lhs.clone() - &rhs, sub);

        lhs -= rhs.clone();
        assert_eq!(lhs, sub);
        lhs -= &rhs;
        assert_eq!(lhs, sub - rhs);
    }

    /// Folding in 1D with an even number of elements, i.e. without a diagonal.
    #[test]
    fn test_fold_4() {
        let sfs = sfs1d![0., 1., 2., 3.];

        assert_eq!(sfs.fold(), sfs1d![3., 3., 0., 0.],);
    }

    /// Folding in 1D with an odd number of elements: the middle element is on the diagonal.
    #[test]
    fn test_fold_5() {
        let sfs = sfs1d![0., 1., 2., 3., 4.];

        assert_eq!(sfs.fold(), sfs1d![4., 4., 2., 0., 0.],);
    }

    /// Folding a square 2D SFS, which has a diagonal.
    #[test]
    fn test_fold_3x3() {
        #[rustfmt::skip]
        let sfs = sfs2d![
            [0., 1., 2.],
            [3., 4., 5.],
            [6., 7., 8.],
        ];

        #[rustfmt::skip]
        let expected = sfs2d![
            [8., 8., 4.],
            [8., 4., 0.],
            [4., 0., 0.],
        ];

        assert_eq!(sfs.fold(), expected);
    }

    /// Folding a non-square 2D SFS that has a diagonal.
    #[test]
    fn test_fold_2x4() {
        #[rustfmt::skip]
        let sfs = sfs2d![
            [0., 1., 2., 3.],
            [4., 5., 6., 7.],
        ];

        #[rustfmt::skip]
        let expected = sfs2d![
            [7., 7.,  3.5, 0.],
            [7., 3.5, 0.,  0.],
        ];

        assert_eq!(sfs.fold(), expected);
    }

    /// Folding a non-square 2D SFS without a diagonal.
    #[test]
    fn test_fold_3x4() {
        #[rustfmt::skip]
        let sfs = sfs2d![
            [0., 1., 2.,  3.],
            [4., 5., 6.,  7.],
            [8., 9., 10., 11.],
        ];

        #[rustfmt::skip]
        let expected = sfs2d![
            [11., 11., 11., 0.],
            [11., 11., 0.,  0.],
            [11., 0.,  0.,  0.],
        ];

        assert_eq!(sfs.fold(), expected);
    }

    /// Folding a wider non-square 2D SFS that has a diagonal.
    #[test]
    fn test_fold_3x7() {
        #[rustfmt::skip]
        let sfs = sfs2d![
            [ 0.,  1.,  2.,  3.,  4.,  5.,  6.],
            [ 7.,  8.,  9., 10., 11., 12., 13.],
            [14., 15., 16., 17., 18., 19., 20.],
        ];

        #[rustfmt::skip]
        let expected = sfs2d![
            [20., 20., 20., 20., 10., 0., 0.],
            [20., 20., 20., 10., 0.,  0., 0.],
            [20., 20., 10., 0.,  0.,  0., 0.],
        ];

        assert_eq!(sfs.fold(), expected);
    }

    /// Folding a 3D SFS without a diagonal.
    #[test]
    fn test_fold_2x2x2() {
        let sfs = USfs::from_iter_shape((0..8).map(|x| x as f64), [2, 2, 2]).unwrap();

        #[rustfmt::skip]
        let expected = USfs::from_vec_shape(
            vec![
                7., 7.,
                7., 0.,

                7., 0.,
                0., 0.,
            ],
            [2, 2, 2]
        ).unwrap();

        assert_eq!(sfs.fold(), expected);
    }

    /// Folding a non-cubic 3D SFS that has a diagonal.
    #[test]
    fn test_fold_2x3x2() {
        let sfs = USfs::from_iter_shape((0..12).map(|x| x as f64), [2, 3, 2]).unwrap();

        #[rustfmt::skip]
        let expected = USfs::from_vec_shape(
            vec![
                11., 11.,
                11., 5.5,
                5.5, 0.,

                11., 5.5,
                5.5, 0.,
                0.,  0.,
            ],
            [2, 3, 2]
        ).unwrap();

        assert_eq!(sfs.fold(), expected);
    }

    /// Folding a cubic 3D SFS that has a diagonal.
    #[test]
    fn test_fold_3x3x3() {
        let sfs = USfs::from_iter_shape((0..27).map(|x| x as f64), [3, 3, 3]).unwrap();

        #[rustfmt::skip]
        let expected = USfs::from_vec_shape(
            vec![
                26., 26., 26.,
                26., 26., 13.,
                26., 13., 0.,

                26., 26., 13.,
                26., 13., 0.,
                13., 0.,  0.,

                26., 13., 0.,
                13., 0.,  0.,
                0.,  0.,  0.,
            ],
            [3, 3, 3]
        ).unwrap();

        assert_eq!(sfs.fold(), expected);
    }
}
1092}