// tensorlogic_infer/partitioned/reducer.rs

use super::config::{AccumulationStrategy, PartitionConfig};
/// Errors that can occur during partitioned reduction operations.
#[derive(Debug, thiserror::Error)]
pub enum PartitionedError {
    /// The input slice held no elements.
    #[error("Empty input for reduction")]
    EmptyInput,

    /// The configured chunk size was zero.
    #[error("Chunk size must be > 0, got {0}")]
    InvalidChunkSize(usize),

    /// The flat data length did not match the product of the given shape.
    #[error("Shape mismatch: expected {expected:?}, got {got:?}")]
    ShapeMismatch {
        expected: Vec<usize>,
        got: Vec<usize>,
    },

    /// A numerically degenerate situation (e.g. all -inf input to log-sum-exp,
    /// or a zero element count when computing a mean).
    #[error("Numerical issue: {0}")]
    NumericalIssue(String),

    /// The requested reduction axis was >= the tensor's dimensionality.
    #[error("Axis {axis} out of range for shape {ndim}D tensor")]
    AxisOutOfRange { axis: usize, ndim: usize },
}
39
/// Counters describing the work performed by a partitioned reducer.
#[derive(Debug, Clone, Default)]
pub struct PartitionedStats {
    /// Number of chunks reduced so far.
    pub chunks_processed: usize,
    /// Total input elements consumed across all chunks.
    pub total_elements_processed: usize,
    /// Largest single chunk length observed.
    pub peak_chunk_size: usize,
}
51
/// Chunk-wise reducer driven by a [`PartitionConfig`]; accumulates
/// [`PartitionedStats`] as it processes input.
pub struct PartitionedReducer {
    /// Chunking and accumulation-strategy settings (from `super::config`).
    config: PartitionConfig,
    /// Running counters updated during reductions.
    stats: PartitionedStats,
}
62
63impl PartitionedReducer {
64 pub fn new(config: PartitionConfig) -> Self {
66 PartitionedReducer {
67 config,
68 stats: PartitionedStats::default(),
69 }
70 }
71
72 pub fn reduce_all(&mut self, data: &[f64]) -> Result<f64, PartitionedError> {
78 if data.is_empty() {
79 return Err(PartitionedError::EmptyInput);
80 }
81 if self.config.chunk_size == 0 {
82 return Err(PartitionedError::InvalidChunkSize(0));
83 }
84
85 if self.config.accumulation == AccumulationStrategy::LogSumExp {
86 return self.log_sum_exp(data);
87 }
88
89 let (mut acc, needs_count) = self.initial_accumulator();
90 let mut total_count = 0usize;
91
92 for chunk in data.chunks(self.config.chunk_size) {
93 let chunk_len = chunk.len();
94 let chunk_result = self.reduce_chunk(chunk)?;
95 acc = self.combine(acc, chunk_result, &self.config.accumulation)?;
96 total_count += chunk_len;
97 self.stats.chunks_processed += 1;
98 self.stats.total_elements_processed += chunk_len;
99 if chunk_len > self.stats.peak_chunk_size {
100 self.stats.peak_chunk_size = chunk_len;
101 }
102 }
103
104 if needs_count {
105 let count = total_count as f64;
107 if count == 0.0 {
108 return Err(PartitionedError::NumericalIssue(
109 "zero element count for mean".to_string(),
110 ));
111 }
112 acc /= count;
113 }
114
115 Ok(acc)
116 }
117
118 pub fn reduce_axis(
124 &mut self,
125 data: &[f64],
126 shape: &[usize],
127 axis: usize,
128 ) -> Result<(Vec<f64>, Vec<usize>), PartitionedError> {
129 if shape.is_empty() {
130 return Err(PartitionedError::AxisOutOfRange { axis, ndim: 0 });
131 }
132 if axis >= shape.len() {
133 return Err(PartitionedError::AxisOutOfRange {
134 axis,
135 ndim: shape.len(),
136 });
137 }
138
139 let total_elements: usize = shape.iter().product();
140 if data.len() != total_elements {
141 return Err(PartitionedError::ShapeMismatch {
142 expected: shape.to_vec(),
143 got: vec![data.len()],
144 });
145 }
146 if data.is_empty() {
147 return Err(PartitionedError::EmptyInput);
148 }
149
150 let out_shape: Vec<usize> = shape
152 .iter()
153 .enumerate()
154 .filter(|&(i, _)| i != axis)
155 .map(|(_, &d)| d)
156 .collect();
157 let out_len: usize = out_shape.iter().product::<usize>().max(1);
158
159 let stride_before: usize = shape[..axis].iter().product::<usize>().max(1);
163 let axis_len: usize = shape[axis];
164 let stride_after: usize = shape[axis + 1..].iter().product::<usize>().max(1);
165
166 let mut out = vec![self.initial_scalar(); out_len];
167 let mut counts = vec![0usize; out_len];
168
169 for before in 0..stride_before {
172 for after in 0..stride_after {
173 let out_idx = before * stride_after + after;
174 let values: Vec<f64> = (0..axis_len)
176 .map(|k| data[before * axis_len * stride_after + k * stride_after + after])
177 .collect();
178
179 let mut tmp = PartitionedReducer::new(self.config.clone());
181 let reduced = tmp.reduce_all(&values).map_err(|e| match e {
182 PartitionedError::EmptyInput => PartitionedError::EmptyInput,
183 other => other,
184 })?;
185 self.stats.chunks_processed += tmp.stats.chunks_processed;
186 self.stats.total_elements_processed += tmp.stats.total_elements_processed;
187 if tmp.stats.peak_chunk_size > self.stats.peak_chunk_size {
188 self.stats.peak_chunk_size = tmp.stats.peak_chunk_size;
189 }
190
191 out[out_idx] = reduced;
192 counts[out_idx] += axis_len;
193 }
194 }
195
196 let _ = counts;
198
199 Ok((out, out_shape))
200 }
201
202 pub fn log_sum_exp(&self, data: &[f64]) -> Result<f64, PartitionedError> {
206 if data.is_empty() {
207 return Err(PartitionedError::EmptyInput);
208 }
209
210 let mut global_max = f64::NEG_INFINITY;
212 for chunk in data.chunks(self.config.chunk_size.max(1)) {
213 for &x in chunk {
214 if x > global_max {
215 global_max = x;
216 }
217 }
218 }
219
220 if !global_max.is_finite() {
221 return Err(PartitionedError::NumericalIssue(
222 "all -inf values in log_sum_exp input".to_string(),
223 ));
224 }
225
226 let mut sum_exp = 0.0_f64;
228 for chunk in data.chunks(self.config.chunk_size.max(1)) {
229 for &x in chunk {
230 sum_exp += (x - global_max).exp();
231 }
232 }
233
234 if sum_exp <= 0.0 || !sum_exp.is_finite() {
235 return Err(PartitionedError::NumericalIssue(format!(
236 "sum_exp={sum_exp} after max subtraction"
237 )));
238 }
239
240 Ok(global_max + sum_exp.ln())
241 }
242
243 pub fn stats(&self) -> &PartitionedStats {
245 &self.stats
246 }
247
248 pub fn reset_stats(&mut self) {
250 self.stats = PartitionedStats::default();
251 }
252
253 fn reduce_chunk(&self, chunk: &[f64]) -> Result<f64, PartitionedError> {
259 if chunk.is_empty() {
260 return Err(PartitionedError::EmptyInput);
261 }
262 match self.config.accumulation {
263 AccumulationStrategy::Sum | AccumulationStrategy::Mean => Ok(chunk.iter().sum::<f64>()),
264 AccumulationStrategy::Max => chunk
265 .iter()
266 .copied()
267 .reduce(f64::max)
268 .ok_or(PartitionedError::EmptyInput),
269 AccumulationStrategy::Min => chunk
270 .iter()
271 .copied()
272 .reduce(f64::min)
273 .ok_or(PartitionedError::EmptyInput),
274 AccumulationStrategy::Product => Ok(chunk.iter().product::<f64>()),
275 AccumulationStrategy::LogSumExp => {
276 Err(PartitionedError::NumericalIssue(
278 "LogSumExp should be routed through log_sum_exp()".to_string(),
279 ))
280 }
281 }
282 }
283
284 fn combine(
286 &self,
287 acc: f64,
288 new_val: f64,
289 strategy: &AccumulationStrategy,
290 ) -> Result<f64, PartitionedError> {
291 match strategy {
292 AccumulationStrategy::Sum | AccumulationStrategy::Mean => Ok(acc + new_val),
293 AccumulationStrategy::Max => Ok(acc.max(new_val)),
294 AccumulationStrategy::Min => Ok(acc.min(new_val)),
295 AccumulationStrategy::Product => Ok(acc * new_val),
296 AccumulationStrategy::LogSumExp => Err(PartitionedError::NumericalIssue(
297 "LogSumExp should be routed through log_sum_exp()".to_string(),
298 )),
299 }
300 }
301
302 fn initial_accumulator(&self) -> (f64, bool) {
304 match self.config.accumulation {
305 AccumulationStrategy::Sum => (0.0, false),
306 AccumulationStrategy::Mean => (0.0, true), AccumulationStrategy::Max => (f64::NEG_INFINITY, false),
308 AccumulationStrategy::Min => (f64::INFINITY, false),
309 AccumulationStrategy::Product => (1.0, false),
310 AccumulationStrategy::LogSumExp => (0.0, false),
311 }
312 }
313
314 fn initial_scalar(&self) -> f64 {
316 match self.config.accumulation {
317 AccumulationStrategy::Sum | AccumulationStrategy::Mean => 0.0,
318 AccumulationStrategy::Max => f64::NEG_INFINITY,
319 AccumulationStrategy::Min => f64::INFINITY,
320 AccumulationStrategy::Product => 1.0,
321 AccumulationStrategy::LogSumExp => 0.0,
322 }
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Convenience: a reducer with chunk size 4 and the requested strategy.
    fn make_reducer(strategy: AccumulationStrategy) -> PartitionedReducer {
        PartitionedReducer::new(PartitionConfig::new(4).with_strategy(strategy))
    }

    #[test]
    fn test_reduce_all_sum() {
        // 1 + 2 + ... + 10 = 55, spread across three chunks (4 + 4 + 2).
        let mut r = make_reducer(AccumulationStrategy::Sum);
        let data: Vec<f64> = (1..=10).map(f64::from).collect();
        let result = r.reduce_all(&data).expect("sum ok");
        assert!((result - 55.0).abs() < 1e-12, "sum={result} expected=55");
    }

    #[test]
    fn test_reduce_all_max() {
        let mut r = make_reducer(AccumulationStrategy::Max);
        let data = [3.0_f64, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0];
        let result = r.reduce_all(&data).expect("max ok");
        assert!((result - 9.0).abs() < 1e-12, "max={result} expected=9");
    }

    #[test]
    fn test_reduce_all_min() {
        let mut r = make_reducer(AccumulationStrategy::Min);
        let data = [3.0_f64, 1.0, 4.0, 1.0, 5.0, -2.0, 9.0, 6.0];
        let result = r.reduce_all(&data).expect("min ok");
        assert!((result - (-2.0)).abs() < 1e-12, "min={result} expected=-2");
    }

    #[test]
    fn test_reduce_all_mean() {
        let mut r = make_reducer(AccumulationStrategy::Mean);
        let data = [1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let result = r.reduce_all(&data).expect("mean ok");
        assert!((result - 3.0).abs() < 1e-10, "mean={result} expected=3.0");
    }

    #[test]
    fn test_log_sum_exp_numerically_stable() {
        // Naive exp(1000) overflows; the shifted form must not.
        let r = PartitionedReducer::new(
            PartitionConfig::new(16).with_strategy(AccumulationStrategy::LogSumExp),
        );
        let result = r.log_sum_exp(&[1000.0_f64, 1001.0]).expect("lse ok");
        // LSE(1000, 1001) = 1000 + ln(1 + e).
        let expected = 1000.0_f64 + (1.0_f64 + std::f64::consts::E).ln();
        assert!(
            (result - expected).abs() < 1e-10,
            "lse={result} expected={expected}"
        );
    }

    #[test]
    fn test_empty_input_error() {
        let mut r = make_reducer(AccumulationStrategy::Sum);
        assert!(
            matches!(r.reduce_all(&[]), Err(PartitionedError::EmptyInput)),
            "expected EmptyInput error"
        );
    }
}