sfs_core/spectrum/
folded.rs

1use std::{cmp::Ordering, marker::PhantomData};
2
3use crate::Array;
4
5use super::{Scs, Spectrum, State};
6
7/// A folded spectrum.
8#[derive(Debug, PartialEq)]
9pub struct Folded<S: State> {
10    array: Array<Option<f64>>, // "lower" triangle is None,
11    state: PhantomData<S>,
12}
13
14impl<S: State> Folded<S> {
15    pub(super) fn from_spectrum(spectrum: &Spectrum<S>) -> Self {
16        let n = spectrum.elements();
17        let total_count = spectrum.shape().iter().sum::<usize>() - spectrum.shape().len();
18
19        // In general, this point divides the folding line. Since we are folding onto the "upper"
20        // part of the array, we want to fold anything "below" it onto something "above" it.
21        let mid_count = total_count / 2;
22
23        // The spectrum may or may not have a "diagonal", i.e. a hyperplane that falls exactly on
24        // the midpoint. If such a diagonal exists, we need to handle it as a special case when
25        // folding below.
26        //
27        // For example, in 1D a spectrum with five elements has a "diagonal", marked X:
28        // [-, -, X, -, -]
29        // Whereas on with four elements would not.
30        //
31        // In two dimensions, e.g. three-by-three elements has a diagonal:
32        // [-, -, X]
33        // [-, X, -]
34        // [X, -, -]
35        // whereas two-by-three would not. On the other hand, two-by-four has a diagonal:
36        // [-, -, X, -]
37        // [-, X, -, -]
38        //
39        // Note that even-ploidy data should always have a diagonal, whereas odd-ploidy data
40        // may or may not.
41        let has_diagonal = total_count % 2 == 0;
42
43        // Note that we cannot use the algorithm below in-place, since the reverse iterator
44        // may reach elements that have already been folded, which causes bugs. Hence we fold
45        // into a zero-initialised copy.
46        let mut array = Array::from_element(None, spectrum.shape().clone());
47
48        // We iterate over indices rather than values since we have to mutate on the array
49        // while looking at it from both directions.
50        (0..n).zip((0..n).rev()).for_each(|(i, rev_i)| {
51            let count = spectrum.shape().index_sum_from_flat_unchecked(i);
52
53            let src = spectrum.array.as_slice();
54            let dst = array.as_mut_slice();
55
56            match (count.cmp(&mid_count), has_diagonal) {
57                (Ordering::Less, _) | (Ordering::Equal, false) => {
58                    // We are in the upper part of the spectrum that should be folded onto.
59                    dst[i] = Some(src[i] + src[rev_i]);
60                }
61                (Ordering::Equal, true) => {
62                    // We are on a diagonal, which must be handled as a special case:
63                    // there are apparently different opinions on what the most correct
64                    // thing to do is. This adopts the same strategy as e.g. in dadi.
65                    dst[i] = Some(0.5 * src[i] + 0.5 * src[rev_i]);
66                }
67                (Ordering::Greater, _) => {
68                    // We are in the lower part of the spectrum to be filled with None;
69                    dst[i] = None;
70                }
71            }
72        });
73
74        Self {
75            array,
76            state: PhantomData,
77        }
78    }
79
80    /// Returns an unfolded spectrum based on the folded spectrum, filling the folded elements with
81    /// the provided element.
82    pub fn into_spectrum(&self, fill: f64) -> Spectrum<S> {
83        let data = Vec::from_iter(self.array.iter().map(|x| x.unwrap_or(fill)));
84        let shape = self.array.shape().clone();
85        let array = Array::new_unchecked(data, shape);
86
87        Scs::from(array).into_state_unchecked()
88    }
89}
90
91impl<S: State> Clone for Folded<S> {
92    fn clone(&self) -> Self {
93        Self {
94            array: self.array.clone(),
95            state: PhantomData,
96        }
97    }
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fold_4() {
        let scs = Scs::from_range(0..4, 4).unwrap();
        let expected = Scs::new([3., 3., 0., 0.], 4).unwrap();
        assert_eq!(scs.fold().into_spectrum(0.0), expected);
    }

    #[test]
    fn test_fold_5() {
        let scs = Scs::from_range(0..5, 5).unwrap();
        let expected = Scs::new([4., 4., 2., -1., -1.], 5).unwrap();
        assert_eq!(scs.fold().into_spectrum(-1.), expected);
    }

    #[test]
    fn test_fold_5_distinct() {
        // With `from_range` input every mirrored pair sums to the same value, which would hide
        // a folded value being written to the wrong cell; distinct values rule that out.
        let scs = Scs::new([1., 2., 4., 8., 16.], 5).unwrap();
        let expected = Scs::new([17., 10., 4., 0., 0.], 5).unwrap();
        assert_eq!(scs.fold().into_spectrum(0.0), expected);
    }

    #[test]
    fn test_fold_3x3() {
        let scs = Scs::from_range(0..9, [3, 3]).unwrap();

        #[rustfmt::skip]
        let expected = Scs::new(
            [
                8., 8., 4.,
                8., 4., 0.,
                4., 0., 0.,
            ],
            [3, 3]
        ).unwrap();

        assert_eq!(scs.fold().into_spectrum(0.0), expected);
    }

    #[test]
    fn test_fold_2x2_distinct() {
        // Has a diagonal: the two diagonal cells each receive the average of the pair.
        let scs = Scs::new([1., 2., 4., 8.], [2, 2]).unwrap();

        #[rustfmt::skip]
        let expected = Scs::new(
            [
                9., 3.,
                3., 0.,
            ],
            [2, 2]
        ).unwrap();

        assert_eq!(scs.fold().into_spectrum(0.0), expected);
    }

    #[test]
    fn test_fold_2x3_distinct() {
        // No diagonal: cells with index sum equal to the midpoint are summed with their mirror.
        let scs = Scs::new([1., 2., 4., 8., 16., 32.], [2, 3]).unwrap();

        #[rustfmt::skip]
        let expected = Scs::new(
            [
                33., 18., 0.,
                12.,  0., 0.,
            ],
            [2, 3]
        ).unwrap();

        assert_eq!(scs.fold().into_spectrum(0.0), expected);
    }

    #[test]
    fn test_fold_2x4() {
        let scs = Scs::from_range(0..8, [2, 4]).unwrap();

        #[rustfmt::skip]
        let expected = Scs::new(
            [
                7., 7.,            3.5, f64::INFINITY,
                7., 3.5, f64::INFINITY, f64::INFINITY,
            ],
            [2, 4]
        ).unwrap();

        assert_eq!(scs.fold().into_spectrum(f64::INFINITY), expected);
    }

    #[test]
    fn test_fold_3x4() {
        let scs = Scs::from_range(0..12, [3, 4]).unwrap();

        #[rustfmt::skip]
        let expected = Scs::new(
            [
                11., 11., 11., 0.,
                11., 11.,  0., 0.,
                11.,  0.,  0., 0.,
            ],
            [3, 4]
        ).unwrap();

        assert_eq!(scs.fold().into_spectrum(0.), expected);
    }

    #[test]
    fn test_fold_3x7() {
        let scs = Scs::from_range(0..21, [3, 7]).unwrap();

        #[rustfmt::skip]
        let expected = Scs::new(
            [
                20., 20., 20., 20., 10., 0., 0.,
                20., 20., 20., 10.,  0., 0., 0.,
                20., 20., 10.,  0.,  0., 0., 0.,
            ],
            [3, 7]
        ).unwrap();

        assert_eq!(scs.fold().into_spectrum(0.0), expected);
    }

    #[test]
    fn test_fold_2x2x2() {
        let scs = Scs::from_range(0..8, [2, 2, 2]).unwrap();

        #[rustfmt::skip]
        let expected = Scs::new(
            [
                 7.,  7.,
                 7., -1.,

                 7., -1.,
                -1., -1.,
            ],
            [2, 2, 2]
        ).unwrap();

        assert_eq!(scs.fold().into_spectrum(-1.0), expected);
    }

    #[test]
    fn test_fold_2x3x2() {
        let scs = Scs::from_range(0..12, [2, 3, 2]).unwrap();

        #[rustfmt::skip]
        let expected = Scs::new(
            [
                11., 11.,
                11.,  5.5,
                 5.5, 0.,

                11.,  5.5,
                 5.5, 0.,
                 0.,  0.,
            ],
            [2, 3, 2]
        ).unwrap();

        assert_eq!(scs.fold().into_spectrum(0.0), expected);
    }

    #[test]
    fn test_fold_3x3x3() {
        let scs = Scs::from_range(0..27, [3, 3, 3]).unwrap();

        #[rustfmt::skip]
        let expected = Scs::new(
            [
                26., 26., 26.,
                26., 26., 13.,
                26., 13.,  0.,

                26., 26., 13.,
                26., 13.,  0.,
                13.,  0.,  0.,

                26., 13.,  0.,
                13.,  0.,  0.,
                 0.,  0.,  0.,
            ],
            [3, 3, 3]
        ).unwrap();

        assert_eq!(scs.fold().into_spectrum(0.0), expected);
    }
}