1use std::fmt;
2
3use crate::{array::Shape, utils::hypergeometric_pmf};
4
5use super::{Count, Scs};
6
7#[derive(Clone, Debug, Eq, PartialEq)]
8pub struct PartialProjection {
9 project_to: Count,
10 to_buf: Count,
11}
12
13impl PartialProjection {
14 pub fn from_shape<S>(project_to: S) -> Result<Self, ProjectionError>
15 where
16 S: Into<Shape>,
17 {
18 Count::try_from_shape(project_to.into())
19 .ok_or(ProjectionError::Zero)
20 .map(Self::new)
21 }
22
23 pub fn new<C>(project_to: C) -> Self
24 where
25 C: Into<Count>,
26 {
27 let project_to = project_to.into();
28
29 Self {
30 to_buf: Count::from_zeros(project_to.dimensions()),
31 project_to,
32 }
33 }
34
35 pub fn project_to(&self) -> &Count {
36 &self.project_to
37 }
38
39 pub fn project_unchecked<'a>(
40 &'a mut self,
41 project_from: &'a Count,
42 from: &'a Count,
43 ) -> Projected<'a> {
44 self.to_buf.set_zero();
45
46 Projected::new_unchecked(project_from, &self.project_to, from, &mut self.to_buf)
47 }
48}
49
50#[derive(Clone, Debug, Eq, PartialEq)]
51pub struct Projection {
52 project_from: Count,
53 inner: PartialProjection,
54}
55
56impl Projection {
57 pub fn from_shapes<S>(project_from: S, project_to: S) -> Result<Self, ProjectionError>
58 where
59 S: Into<Shape>,
60 {
61 match (
62 Count::try_from_shape(project_from.into()),
63 Count::try_from_shape(project_to.into()),
64 ) {
65 (Some(project_from), Some(project_to)) => Self::new(project_from, project_to),
66 (None, None) | (None, Some(_)) | (Some(_), None) => Err(ProjectionError::Zero),
67 }
68 }
69
70 pub fn new<C>(project_from: C, project_to: C) -> Result<Self, ProjectionError>
71 where
72 C: Into<Count>,
73 {
74 let from = project_from.into();
75 let to = project_to.into();
76
77 if from.dimensions() == to.dimensions() {
78 if let Some(dimension) = from
79 .iter()
80 .zip(to.iter())
81 .enumerate()
82 .find_map(|(i, (from, to))| (from < to).then_some(i))
83 {
84 Err(ProjectionError::InvalidProjection {
85 dimension,
86 from: from[dimension],
87 to: to[dimension],
88 })
89 } else {
90 Ok(Self::new_unchecked(from, to))
91 }
92 } else if from.dimensions() == 0 {
93 Err(ProjectionError::Empty)
94 } else {
95 Err(ProjectionError::UnequalDimensions {
96 from: from.dimensions(),
97 to: to.dimensions(),
98 })
99 }
100 }
101
102 pub fn new_unchecked<C>(project_from: C, project_to: C) -> Self
103 where
104 C: Into<Count>,
105 {
106 Self {
107 project_from: project_from.into(),
108 inner: PartialProjection::new(project_to),
109 }
110 }
111
112 pub fn project_unchecked<'a>(&'a mut self, from: &'a Count) -> Projected<'a> {
113 self.inner.project_unchecked(&self.project_from, from)
114 }
115}
116
117#[derive(Debug)]
118pub struct Projected<'a> {
119 iter: ProjectIter<'a>,
120 weight: f64,
121}
122
123impl<'a> Projected<'a> {
124 pub fn add_unchecked(self, to: &mut Scs) {
125 to.inner_mut()
126 .iter_mut()
127 .zip(self.iter)
128 .for_each(|(to, projected)| *to += projected * self.weight);
129 }
130
131 fn new_unchecked(
132 project_from: &'a Count,
133 project_to: &'a Count,
134 from: &'a Count,
135 to: &'a mut Count,
136 ) -> Self {
137 Self {
138 iter: ProjectIter::new_unchecked(project_from, project_to, from, to),
139 weight: 1.0,
140 }
141 }
142
143 pub fn into_weighted(mut self, weight: f64) -> Self {
144 self.weight = weight;
145 self
146 }
147}
148
149#[derive(Debug)]
150struct ProjectIter<'a> {
151 project_from: &'a Count,
152 project_to: &'a Count,
153 from: &'a Count,
154 to: &'a mut Count,
155 index: usize,
156}
157
158impl<'a> ProjectIter<'a> {
159 fn dimensions(&self) -> usize {
160 self.to.len()
161 }
162
163 fn impl_next_rec(&mut self, axis: usize) -> Option<<Self as Iterator>::Item> {
164 if self.index == 0 {
165 self.index += 1;
166 return Some(self.project_value());
167 };
168
169 self.to[axis] += 1;
170 if self.to[axis] <= self.project_to[axis] {
171 self.index += 1;
172 Some(self.project_value())
173 } else if axis > 0 {
174 self.to[axis] = 0;
175 self.impl_next_rec(axis - 1)
176 } else {
177 None
178 }
179 }
180
181 fn new_unchecked(
182 project_from: &'a Count,
183 project_to: &'a Count,
184 from: &'a Count,
185 to: &'a mut Count,
186 ) -> Self {
187 Self {
188 project_from,
189 project_to,
190 from,
191 to,
192 index: 0,
193 }
194 }
195
196 fn project_value(&self) -> f64 {
197 self.project_from
198 .iter()
199 .zip(self.from.iter())
200 .zip(self.project_to.iter())
201 .zip(self.to.iter())
202 .map(|(((&size, &successes), &draws), &observed)| {
203 hypergeometric_pmf(size as u64, successes as u64, draws as u64, observed as u64)
204 })
205 .fold(1.0, |joint, probability| joint * probability)
206 }
207}
208
209impl<'a> Iterator for ProjectIter<'a> {
210 type Item = f64;
211
212 fn next(&mut self) -> Option<Self::Item> {
213 self.impl_next_rec(self.dimensions() - 1)
214 }
215}
216
/// Errors produced when constructing a projection.
#[derive(Debug)]
pub enum ProjectionError {
    /// The counts to project have no dimensions.
    Empty,
    /// The target count exceeds the source count in some dimension.
    InvalidProjection {
        /// Index of the offending dimension.
        dimension: usize,
        /// Source count in that dimension.
        from: usize,
        /// Target count in that dimension.
        to: usize,
    },
    /// The source and target counts have different numbers of dimensions.
    UnequalDimensions {
        /// Number of dimensions of the source count.
        from: usize,
        /// Number of dimensions of the target count.
        to: usize,
    },
    /// A shape could not be converted into a count (reported as shape zero).
    Zero,
}

impl fmt::Display for ProjectionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Empty => f.write_str("cannot project empty counts"),
            Self::InvalidProjection {
                dimension,
                from,
                to,
            } => write!(
                f,
                "cannot project from count {from} to count {to} in dimension {dimension}"
            ),
            Self::UnequalDimensions { from, to } => write!(
                f,
                "cannot project from one number of dimensions ({from}) to another ({to})"
            ),
            Self::Zero => f.write_str("cannot project to or from shape zero"),
        }
    }
}

impl std::error::Error for ProjectionError {
    // No variant wraps another error, so the default `source` (returning `None`) is correct.
}
268
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_projection_errors() {
        // Counts with different numbers of dimensions cannot be projected onto each other.
        let unequal = Projection::new(vec![2, 3], vec![1]);
        assert!(matches!(
            unequal,
            Err(ProjectionError::UnequalDimensions { .. })
        ));

        // Projecting upwards (3 > 2 in the first dimension) is rejected.
        let upwards = Projection::new([2, 3], [3, 2]);
        assert!(matches!(
            upwards,
            Err(ProjectionError::InvalidProjection { .. })
        ));
    }

    // Projects `$proj` from the given count and compares every yielded probability with the
    // expected values, to within 1e-6.
    macro_rules! assert_project_to {
        ($proj:ident from [$($count:literal),+] is [$($p:literal),+]) => {{
            let probabilities: Vec<_> = $proj
                .project_unchecked(&Count::from([$($count),+]))
                .iter
                .collect();
            assert_approx_eq!(probabilities, vec![$($p),+], epsilon = 1e-6);
        }};
    }

    #[test]
    fn test_project_6_to_2() {
        let mut projection = Projection::new_unchecked(Count::from(6), Count::from(2));

        // Each row is the hypergeometric distribution of 2 draws from 6 with k successes.
        assert_project_to!(projection from [0] is [1.000000, 0.000000, 0.000000]);
        assert_project_to!(projection from [1] is [0.666666, 0.333333, 0.000000]);
        assert_project_to!(projection from [2] is [0.400000, 0.533333, 0.066667]);
        assert_project_to!(projection from [3] is [0.200000, 0.600000, 0.200000]);
        assert_project_to!(projection from [4] is [0.066667, 0.533333, 0.400000]);
        assert_project_to!(projection from [5] is [0.000000, 0.333333, 0.666666]);
        assert_project_to!(projection from [6] is [0.000000, 0.000000, 1.000000]);
    }

    #[test]
    fn test_project_2x2_to_1x1() {
        let mut projection = Projection::new_unchecked(Count::from([2, 2]), Count::from([1, 1]));

        // Outputs are ordered [0,0], [0,1], [1,0], [1,1] (last axis fastest).
        assert_project_to!(projection from [0, 0] is [1.00, 0.00, 0.00, 0.00]);
        assert_project_to!(projection from [0, 1] is [0.50, 0.50, 0.00, 0.00]);
        assert_project_to!(projection from [0, 2] is [0.00, 1.00, 0.00, 0.00]);
        assert_project_to!(projection from [1, 0] is [0.50, 0.00, 0.50, 0.00]);
        assert_project_to!(projection from [1, 1] is [0.25, 0.25, 0.25, 0.25]);
        assert_project_to!(projection from [1, 2] is [0.00, 0.50, 0.00, 0.50]);
        assert_project_to!(projection from [2, 0] is [0.00, 0.00, 1.00, 0.00]);
        assert_project_to!(projection from [2, 1] is [0.00, 0.00, 0.50, 0.50]);
        assert_project_to!(projection from [2, 2] is [0.00, 0.00, 0.00, 1.00]);
    }
}