1use std::{cmp::Ordering, marker::PhantomData};
2
3use crate::Array;
4
5use super::{Scs, Spectrum, State};
6
7#[derive(Debug, PartialEq)]
9pub struct Folded<S: State> {
10 array: Array<Option<f64>>, state: PhantomData<S>,
12}
13
14impl<S: State> Folded<S> {
15 pub(super) fn from_spectrum(spectrum: &Spectrum<S>) -> Self {
16 let n = spectrum.elements();
17 let total_count = spectrum.shape().iter().sum::<usize>() - spectrum.shape().len();
18
19 let mid_count = total_count / 2;
22
23 let has_diagonal = total_count % 2 == 0;
42
43 let mut array = Array::from_element(None, spectrum.shape().clone());
47
48 (0..n).zip((0..n).rev()).for_each(|(i, rev_i)| {
51 let count = spectrum.shape().index_sum_from_flat_unchecked(i);
52
53 let src = spectrum.array.as_slice();
54 let dst = array.as_mut_slice();
55
56 match (count.cmp(&mid_count), has_diagonal) {
57 (Ordering::Less, _) | (Ordering::Equal, false) => {
58 dst[i] = Some(src[i] + src[rev_i]);
60 }
61 (Ordering::Equal, true) => {
62 dst[i] = Some(0.5 * src[i] + 0.5 * src[rev_i]);
66 }
67 (Ordering::Greater, _) => {
68 dst[i] = None;
70 }
71 }
72 });
73
74 Self {
75 array,
76 state: PhantomData,
77 }
78 }
79
80 pub fn into_spectrum(&self, fill: f64) -> Spectrum<S> {
83 let data = Vec::from_iter(self.array.iter().map(|x| x.unwrap_or(fill)));
84 let shape = self.array.shape().clone();
85 let array = Array::new_unchecked(data, shape);
86
87 Scs::from(array).into_state_unchecked()
88 }
89}
90
91impl<S: State> Clone for Folded<S> {
92 fn clone(&self) -> Self {
93 Self {
94 array: self.array.clone(),
95 state: PhantomData,
96 }
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// Folds `scs`, unfolds the result with `fill` for folded-away entries, and asserts that
    /// it equals `expected`.
    #[track_caller]
    fn assert_fold(scs: Scs, fill: f64, expected: Scs) {
        assert_eq!(scs.fold().into_spectrum(fill), expected);
    }

    #[test]
    fn test_fold_4() {
        let expected = Scs::new([3., 3., 0., 0.], 4).unwrap();

        assert_fold(Scs::from_range(0..4, 4).unwrap(), 0.0, expected);
    }

    #[test]
    fn test_fold_5() {
        let expected = Scs::new([4., 4., 2., -1., -1.], 5).unwrap();

        assert_fold(Scs::from_range(0..5, 5).unwrap(), -1., expected);
    }

    #[test]
    fn test_fold_3x3() {
        #[rustfmt::skip]
        let expected = Scs::new(
            [
                8., 8., 4.,
                8., 4., 0.,
                4., 0., 0.,
            ],
            [3, 3]
        ).unwrap();

        assert_fold(Scs::from_range(0..9, [3, 3]).unwrap(), 0.0, expected);
    }

    #[test]
    fn test_fold_2x4() {
        #[rustfmt::skip]
        let expected = Scs::new(
            [
                7., 7., 3.5, f64::INFINITY,
                7., 3.5, f64::INFINITY, f64::INFINITY,
            ],
            [2, 4]
        ).unwrap();

        assert_fold(Scs::from_range(0..8, [2, 4]).unwrap(), f64::INFINITY, expected);
    }

    #[test]
    fn test_fold_3x4() {
        #[rustfmt::skip]
        let expected = Scs::new(
            [
                11., 11., 11., 0.,
                11., 11., 0., 0.,
                11., 0., 0., 0.,
            ],
            [3, 4]
        ).unwrap();

        assert_fold(Scs::from_range(0..12, [3, 4]).unwrap(), 0., expected);
    }

    #[test]
    fn test_fold_3x7() {
        #[rustfmt::skip]
        let expected = Scs::new(
            [
                20., 20., 20., 20., 10., 0., 0.,
                20., 20., 20., 10., 0., 0., 0.,
                20., 20., 10., 0., 0., 0., 0.,
            ],
            [3, 7]
        ).unwrap();

        assert_fold(Scs::from_range(0..21, [3, 7]).unwrap(), 0.0, expected);
    }

    #[test]
    fn test_fold_2x2x2() {
        #[rustfmt::skip]
        let expected = Scs::new(
            [
                7., 7.,
                7., -1.,

                7., -1.,
                -1., -1.,
            ],
            [2, 2, 2]
        ).unwrap();

        assert_fold(Scs::from_range(0..8, [2, 2, 2]).unwrap(), -1.0, expected);
    }

    #[test]
    fn test_fold_2x3x2() {
        #[rustfmt::skip]
        let expected = Scs::new(
            [
                11., 11.,
                11., 5.5,
                5.5, 0.,

                11., 5.5,
                5.5, 0.,
                0., 0.,
            ],
            [2, 3, 2]
        ).unwrap();

        assert_fold(Scs::from_range(0..12, [2, 3, 2]).unwrap(), 0.0, expected);
    }

    #[test]
    fn test_fold_3x3x3() {
        #[rustfmt::skip]
        let expected = Scs::new(
            [
                26., 26., 26.,
                26., 26., 13.,
                26., 13., 0.,

                26., 26., 13.,
                26., 13., 0.,
                13., 0., 0.,

                26., 13., 0.,
                13., 0., 0.,
                0., 0., 0.,
            ],
            [3, 3, 3]
        ).unwrap();

        assert_fold(Scs::from_range(0..27, [3, 3, 3]).unwrap(), 0.0, expected);
    }
}