sfs_core/spectrum/project.rs

1use std::fmt;
2
3use crate::{array::Shape, utils::hypergeometric_pmf};
4
5use super::{Count, Scs};
6
7#[derive(Clone, Debug, Eq, PartialEq)]
8pub struct PartialProjection {
9    project_to: Count,
10    to_buf: Count,
11}
12
13impl PartialProjection {
14    pub fn from_shape<S>(project_to: S) -> Result<Self, ProjectionError>
15    where
16        S: Into<Shape>,
17    {
18        Count::try_from_shape(project_to.into())
19            .ok_or(ProjectionError::Zero)
20            .map(Self::new)
21    }
22
23    pub fn new<C>(project_to: C) -> Self
24    where
25        C: Into<Count>,
26    {
27        let project_to = project_to.into();
28
29        Self {
30            to_buf: Count::from_zeros(project_to.dimensions()),
31            project_to,
32        }
33    }
34
35    pub fn project_to(&self) -> &Count {
36        &self.project_to
37    }
38
39    pub fn project_unchecked<'a>(
40        &'a mut self,
41        project_from: &'a Count,
42        from: &'a Count,
43    ) -> Projected<'a> {
44        self.to_buf.set_zero();
45
46        Projected::new_unchecked(project_from, &self.project_to, from, &mut self.to_buf)
47    }
48}
49
50#[derive(Clone, Debug, Eq, PartialEq)]
51pub struct Projection {
52    project_from: Count,
53    inner: PartialProjection,
54}
55
56impl Projection {
57    pub fn from_shapes<S>(project_from: S, project_to: S) -> Result<Self, ProjectionError>
58    where
59        S: Into<Shape>,
60    {
61        match (
62            Count::try_from_shape(project_from.into()),
63            Count::try_from_shape(project_to.into()),
64        ) {
65            (Some(project_from), Some(project_to)) => Self::new(project_from, project_to),
66            (None, None) | (None, Some(_)) | (Some(_), None) => Err(ProjectionError::Zero),
67        }
68    }
69
70    pub fn new<C>(project_from: C, project_to: C) -> Result<Self, ProjectionError>
71    where
72        C: Into<Count>,
73    {
74        let from = project_from.into();
75        let to = project_to.into();
76
77        if from.dimensions() == to.dimensions() {
78            if let Some(dimension) = from
79                .iter()
80                .zip(to.iter())
81                .enumerate()
82                .find_map(|(i, (from, to))| (from < to).then_some(i))
83            {
84                Err(ProjectionError::InvalidProjection {
85                    dimension,
86                    from: from[dimension],
87                    to: to[dimension],
88                })
89            } else {
90                Ok(Self::new_unchecked(from, to))
91            }
92        } else if from.dimensions() == 0 {
93            Err(ProjectionError::Empty)
94        } else {
95            Err(ProjectionError::UnequalDimensions {
96                from: from.dimensions(),
97                to: to.dimensions(),
98            })
99        }
100    }
101
102    pub fn new_unchecked<C>(project_from: C, project_to: C) -> Self
103    where
104        C: Into<Count>,
105    {
106        Self {
107            project_from: project_from.into(),
108            inner: PartialProjection::new(project_to),
109        }
110    }
111
112    pub fn project_unchecked<'a>(&'a mut self, from: &'a Count) -> Projected<'a> {
113        self.inner.project_unchecked(&self.project_from, from)
114    }
115}
116
117#[derive(Debug)]
118pub struct Projected<'a> {
119    iter: ProjectIter<'a>,
120    weight: f64,
121}
122
123impl<'a> Projected<'a> {
124    pub fn add_unchecked(self, to: &mut Scs) {
125        to.inner_mut()
126            .iter_mut()
127            .zip(self.iter)
128            .for_each(|(to, projected)| *to += projected * self.weight);
129    }
130
131    fn new_unchecked(
132        project_from: &'a Count,
133        project_to: &'a Count,
134        from: &'a Count,
135        to: &'a mut Count,
136    ) -> Self {
137        Self {
138            iter: ProjectIter::new_unchecked(project_from, project_to, from, to),
139            weight: 1.0,
140        }
141    }
142
143    pub fn into_weighted(mut self, weight: f64) -> Self {
144        self.weight = weight;
145        self
146    }
147}
148
149#[derive(Debug)]
150struct ProjectIter<'a> {
151    project_from: &'a Count,
152    project_to: &'a Count,
153    from: &'a Count,
154    to: &'a mut Count,
155    index: usize,
156}
157
158impl<'a> ProjectIter<'a> {
159    fn dimensions(&self) -> usize {
160        self.to.len()
161    }
162
163    fn impl_next_rec(&mut self, axis: usize) -> Option<<Self as Iterator>::Item> {
164        if self.index == 0 {
165            self.index += 1;
166            return Some(self.project_value());
167        };
168
169        self.to[axis] += 1;
170        if self.to[axis] <= self.project_to[axis] {
171            self.index += 1;
172            Some(self.project_value())
173        } else if axis > 0 {
174            self.to[axis] = 0;
175            self.impl_next_rec(axis - 1)
176        } else {
177            None
178        }
179    }
180
181    fn new_unchecked(
182        project_from: &'a Count,
183        project_to: &'a Count,
184        from: &'a Count,
185        to: &'a mut Count,
186    ) -> Self {
187        Self {
188            project_from,
189            project_to,
190            from,
191            to,
192            index: 0,
193        }
194    }
195
196    fn project_value(&self) -> f64 {
197        self.project_from
198            .iter()
199            .zip(self.from.iter())
200            .zip(self.project_to.iter())
201            .zip(self.to.iter())
202            .map(|(((&size, &successes), &draws), &observed)| {
203                hypergeometric_pmf(size as u64, successes as u64, draws as u64, observed as u64)
204            })
205            .fold(1.0, |joint, probability| joint * probability)
206    }
207}
208
209impl<'a> Iterator for ProjectIter<'a> {
210    type Item = f64;
211
212    fn next(&mut self) -> Option<Self::Item> {
213        self.impl_next_rec(self.dimensions() - 1)
214    }
215}
216
/// Error returned when a projection cannot be constructed.
#[derive(Debug)]
pub enum ProjectionError {
    /// The count to project from has no dimensions.
    Empty,
    /// In some dimension, the count to project from is smaller than the count to project to.
    InvalidProjection {
        /// Index of the first offending dimension.
        dimension: usize,
        /// Count to project from in that dimension.
        from: usize,
        /// Count to project to in that dimension.
        to: usize,
    },
    /// The counts to project from and to have different numbers of dimensions.
    UnequalDimensions {
        /// Number of dimensions of the count to project from.
        from: usize,
        /// Number of dimensions of the count to project to.
        to: usize,
    },
    /// A shape could not be converted to a count; the message describes this as a zero shape.
    Zero,
}

impl fmt::Display for ProjectionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Empty => write!(f, "cannot project empty counts"),
            Self::InvalidProjection {
                dimension,
                from,
                to,
            } => write!(
                f,
                "cannot project from count {from} to count {to} in dimension {dimension}"
            ),
            Self::UnequalDimensions { from, to } => write!(
                f,
                "cannot project from one number of dimensions ({from}) to another ({to})"
            ),
            Self::Zero => write!(f, "cannot project to or from shape zero"),
        }
    }
}

impl std::error::Error for ProjectionError {}
268
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_projection_errors() {
        // Two dimensions cannot be projected onto one.
        let unequal = Projection::new(vec![2, 3], vec![1]);
        assert!(matches!(
            unequal,
            Err(ProjectionError::UnequalDimensions { .. })
        ));

        // The second dimension would have to grow from 2 to 3.
        let growing = Projection::new([2, 3], [3, 2]);
        assert!(matches!(
            growing,
            Err(ProjectionError::InvalidProjection { .. })
        ));
    }

    // Projects the count `[$count, ...]` with `$proj` and compares the collected values against
    // `[$value, ...]` using `assert_approx_eq!` with `epsilon = 1e-6`.
    macro_rules! assert_project_to {
        ($proj:ident from [$($count:literal),+] is [$($value:literal),+]) => {
            assert_approx_eq!(
                $proj
                    .project_unchecked(&Count::from([$($count),+]))
                    .iter
                    .collect::<Vec<_>>(),
                vec![$($value),+],
                epsilon = 1e-6
            );
        };
    }

    #[test]
    fn test_project_6_to_2() {
        // Each row is the hypergeometric distribution of drawing 2 out of 6, with the given
        // number of successes among the 6; e.g. from 1: P(0) = C(5, 2) / C(6, 2) = 10 / 15.
        let mut proj = Projection::new_unchecked(Count::from(6), Count::from(2));

        assert_project_to!(proj from [0] is [1.000000, 0.000000, 0.000000]);
        assert_project_to!(proj from [1] is [0.666666, 0.333333, 0.000000]);
        assert_project_to!(proj from [2] is [0.400000, 0.533333, 0.066667]);
        assert_project_to!(proj from [3] is [0.200000, 0.600000, 0.200000]);
        assert_project_to!(proj from [4] is [0.066667, 0.533333, 0.400000]);
        assert_project_to!(proj from [5] is [0.000000, 0.333333, 0.666666]);
        assert_project_to!(proj from [6] is [0.000000, 0.000000, 1.000000]);
    }

    #[test]
    fn test_project_2x2_to_1x1() {
        // Dimensions project independently and the values multiply. Output order is row-major:
        // [(0, 0), (0, 1), (1, 0), (1, 1)].
        let mut proj = Projection::new_unchecked(Count::from([2, 2]), Count::from([1, 1]));

        assert_project_to!(proj from [0, 0] is [1.00, 0.00, 0.00, 0.00]);
        assert_project_to!(proj from [0, 1] is [0.50, 0.50, 0.00, 0.00]);
        assert_project_to!(proj from [0, 2] is [0.00, 1.00, 0.00, 0.00]);
        assert_project_to!(proj from [1, 0] is [0.50, 0.00, 0.50, 0.00]);
        assert_project_to!(proj from [1, 1] is [0.25, 0.25, 0.25, 0.25]);
        assert_project_to!(proj from [1, 2] is [0.00, 0.50, 0.00, 0.50]);
        assert_project_to!(proj from [2, 0] is [0.00, 0.00, 1.00, 0.00]);
        assert_project_to!(proj from [2, 1] is [0.00, 0.00, 0.50, 0.50]);
        assert_project_to!(proj from [2, 2] is [0.00, 0.00, 0.00, 1.00]);
    }
}