Skip to main content

tensorlogic_infer/partitioned/
reducer.rs

1//! Memory-efficient partitioned reduction over flat tensors.
2//!
3//! # Design
4//!
5//! All reductions are processed in fixed-size chunks of at most
6//! `PartitionConfig::chunk_size` elements.  This caps peak working-set memory
7//! independently of total tensor size.
8//!
9//! `reduce_axis` implements an axis-wise reduction by iterating over the
10//! *slices* of the flattened N-D tensor that correspond to a given axis.
11
12use super::config::{AccumulationStrategy, PartitionConfig};
13
14// ---------------------------------------------------------------------------
15// Error type
16// ---------------------------------------------------------------------------
17
18/// Errors produced by [`PartitionedReducer`].
19#[derive(Debug, thiserror::Error)]
20pub enum PartitionedError {
21    #[error("Empty input for reduction")]
22    EmptyInput,
23
24    #[error("Chunk size must be > 0, got {0}")]
25    InvalidChunkSize(usize),
26
27    #[error("Shape mismatch: expected {expected:?}, got {got:?}")]
28    ShapeMismatch {
29        expected: Vec<usize>,
30        got: Vec<usize>,
31    },
32
33    #[error("Numerical issue: {0}")]
34    NumericalIssue(String),
35
36    #[error("Axis {axis} out of range for shape {ndim}D tensor")]
37    AxisOutOfRange { axis: usize, ndim: usize },
38}
39
40// ---------------------------------------------------------------------------
41// PartitionedStats
42// ---------------------------------------------------------------------------
43
/// Statistics collected during a partitioned reduction.
///
/// The counters are cumulative across reductions performed by the same
/// reducer until [`PartitionedReducer::reset_stats`] is called.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PartitionedStats {
    /// Number of chunks reduced so far.
    pub chunks_processed: usize,
    /// Total number of input elements consumed across all chunks.
    pub total_elements_processed: usize,
    /// Length of the largest single chunk seen (never above the configured chunk size).
    pub peak_chunk_size: usize,
}
51
52// ---------------------------------------------------------------------------
53// PartitionedReducer
54// ---------------------------------------------------------------------------
55
56/// Performs memory-efficient reductions by splitting input into fixed-size
57/// chunks and accumulating partial results.
58pub struct PartitionedReducer {
59    config: PartitionConfig,
60    stats: PartitionedStats,
61}
62
63impl PartitionedReducer {
64    /// Create a new reducer with the given configuration.
65    pub fn new(config: PartitionConfig) -> Self {
66        PartitionedReducer {
67            config,
68            stats: PartitionedStats::default(),
69        }
70    }
71
72    // ------------------------------------------------------------------
73    // Public API
74    // ------------------------------------------------------------------
75
76    /// Reduce all elements of a flat 1-D slice to a single scalar.
77    pub fn reduce_all(&mut self, data: &[f64]) -> Result<f64, PartitionedError> {
78        if data.is_empty() {
79            return Err(PartitionedError::EmptyInput);
80        }
81        if self.config.chunk_size == 0 {
82            return Err(PartitionedError::InvalidChunkSize(0));
83        }
84
85        if self.config.accumulation == AccumulationStrategy::LogSumExp {
86            return self.log_sum_exp(data);
87        }
88
89        let (mut acc, needs_count) = self.initial_accumulator();
90        let mut total_count = 0usize;
91
92        for chunk in data.chunks(self.config.chunk_size) {
93            let chunk_len = chunk.len();
94            let chunk_result = self.reduce_chunk(chunk)?;
95            acc = self.combine(acc, chunk_result, &self.config.accumulation)?;
96            total_count += chunk_len;
97            self.stats.chunks_processed += 1;
98            self.stats.total_elements_processed += chunk_len;
99            if chunk_len > self.stats.peak_chunk_size {
100                self.stats.peak_chunk_size = chunk_len;
101            }
102        }
103
104        if needs_count {
105            // Mean: divide accumulated sum by total element count
106            let count = total_count as f64;
107            if count == 0.0 {
108                return Err(PartitionedError::NumericalIssue(
109                    "zero element count for mean".to_string(),
110                ));
111            }
112            acc /= count;
113        }
114
115        Ok(acc)
116    }
117
118    /// Reduce along a single axis of an N-dimensional tensor.
119    ///
120    /// `data` is the row-major flat representation of a tensor with the given
121    /// `shape`.  The returned value is the flat representation of the reduced
122    /// tensor together with its shape (the axis dimension is removed).
123    pub fn reduce_axis(
124        &mut self,
125        data: &[f64],
126        shape: &[usize],
127        axis: usize,
128    ) -> Result<(Vec<f64>, Vec<usize>), PartitionedError> {
129        if shape.is_empty() {
130            return Err(PartitionedError::AxisOutOfRange { axis, ndim: 0 });
131        }
132        if axis >= shape.len() {
133            return Err(PartitionedError::AxisOutOfRange {
134                axis,
135                ndim: shape.len(),
136            });
137        }
138
139        let total_elements: usize = shape.iter().product();
140        if data.len() != total_elements {
141            return Err(PartitionedError::ShapeMismatch {
142                expected: shape.to_vec(),
143                got: vec![data.len()],
144            });
145        }
146        if data.is_empty() {
147            return Err(PartitionedError::EmptyInput);
148        }
149
150        // Compute output shape (remove the axis dimension)
151        let out_shape: Vec<usize> = shape
152            .iter()
153            .enumerate()
154            .filter(|&(i, _)| i != axis)
155            .map(|(_, &d)| d)
156            .collect();
157        let out_len: usize = out_shape.iter().product::<usize>().max(1);
158
159        // stride_before: product of dims before the axis
160        // axis_len: size of the reduced axis
161        // stride_after: product of dims after the axis
162        let stride_before: usize = shape[..axis].iter().product::<usize>().max(1);
163        let axis_len: usize = shape[axis];
164        let stride_after: usize = shape[axis + 1..].iter().product::<usize>().max(1);
165
166        let mut out = vec![self.initial_scalar(); out_len];
167        let mut counts = vec![0usize; out_len];
168
169        // For each element in the output we accumulate all axis values in
170        // chunks to stay within memory budget.
171        for before in 0..stride_before {
172            for after in 0..stride_after {
173                let out_idx = before * stride_after + after;
174                // Collect all values along this axis, then reduce
175                let values: Vec<f64> = (0..axis_len)
176                    .map(|k| data[before * axis_len * stride_after + k * stride_after + after])
177                    .collect();
178
179                // Use reduce_all with a temporary config for the strategy
180                let mut tmp = PartitionedReducer::new(self.config.clone());
181                let reduced = tmp.reduce_all(&values).map_err(|e| match e {
182                    PartitionedError::EmptyInput => PartitionedError::EmptyInput,
183                    other => other,
184                })?;
185                self.stats.chunks_processed += tmp.stats.chunks_processed;
186                self.stats.total_elements_processed += tmp.stats.total_elements_processed;
187                if tmp.stats.peak_chunk_size > self.stats.peak_chunk_size {
188                    self.stats.peak_chunk_size = tmp.stats.peak_chunk_size;
189                }
190
191                out[out_idx] = reduced;
192                counts[out_idx] += axis_len;
193            }
194        }
195
196        // For Mean, the reduce_all already divided by count, nothing more needed.
197        let _ = counts;
198
199        Ok((out, out_shape))
200    }
201
202    /// Numerically stable log-sum-exp: `log(Σ exp(x_i))`.
203    ///
204    /// Uses the max-subtraction trick to avoid overflow.
205    pub fn log_sum_exp(&self, data: &[f64]) -> Result<f64, PartitionedError> {
206        if data.is_empty() {
207            return Err(PartitionedError::EmptyInput);
208        }
209
210        // Pass 1: find max value (chunked to reuse chunk_size discipline)
211        let mut global_max = f64::NEG_INFINITY;
212        for chunk in data.chunks(self.config.chunk_size.max(1)) {
213            for &x in chunk {
214                if x > global_max {
215                    global_max = x;
216                }
217            }
218        }
219
220        if !global_max.is_finite() {
221            return Err(PartitionedError::NumericalIssue(
222                "all -inf values in log_sum_exp input".to_string(),
223            ));
224        }
225
226        // Pass 2: accumulate shifted exponentials
227        let mut sum_exp = 0.0_f64;
228        for chunk in data.chunks(self.config.chunk_size.max(1)) {
229            for &x in chunk {
230                sum_exp += (x - global_max).exp();
231            }
232        }
233
234        if sum_exp <= 0.0 || !sum_exp.is_finite() {
235            return Err(PartitionedError::NumericalIssue(format!(
236                "sum_exp={sum_exp} after max subtraction"
237            )));
238        }
239
240        Ok(global_max + sum_exp.ln())
241    }
242
243    /// Return the accumulated statistics since the last reset.
244    pub fn stats(&self) -> &PartitionedStats {
245        &self.stats
246    }
247
248    /// Reset accumulated statistics.
249    pub fn reset_stats(&mut self) {
250        self.stats = PartitionedStats::default();
251    }
252
253    // ------------------------------------------------------------------
254    // Internal helpers
255    // ------------------------------------------------------------------
256
257    /// Reduce a single chunk according to the configured strategy.
258    fn reduce_chunk(&self, chunk: &[f64]) -> Result<f64, PartitionedError> {
259        if chunk.is_empty() {
260            return Err(PartitionedError::EmptyInput);
261        }
262        match self.config.accumulation {
263            AccumulationStrategy::Sum | AccumulationStrategy::Mean => Ok(chunk.iter().sum::<f64>()),
264            AccumulationStrategy::Max => chunk
265                .iter()
266                .copied()
267                .reduce(f64::max)
268                .ok_or(PartitionedError::EmptyInput),
269            AccumulationStrategy::Min => chunk
270                .iter()
271                .copied()
272                .reduce(f64::min)
273                .ok_or(PartitionedError::EmptyInput),
274            AccumulationStrategy::Product => Ok(chunk.iter().product::<f64>()),
275            AccumulationStrategy::LogSumExp => {
276                // Handled separately in reduce_all
277                Err(PartitionedError::NumericalIssue(
278                    "LogSumExp should be routed through log_sum_exp()".to_string(),
279                ))
280            }
281        }
282    }
283
284    /// Combine a running accumulator with a new chunk result.
285    fn combine(
286        &self,
287        acc: f64,
288        new_val: f64,
289        strategy: &AccumulationStrategy,
290    ) -> Result<f64, PartitionedError> {
291        match strategy {
292            AccumulationStrategy::Sum | AccumulationStrategy::Mean => Ok(acc + new_val),
293            AccumulationStrategy::Max => Ok(acc.max(new_val)),
294            AccumulationStrategy::Min => Ok(acc.min(new_val)),
295            AccumulationStrategy::Product => Ok(acc * new_val),
296            AccumulationStrategy::LogSumExp => Err(PartitionedError::NumericalIssue(
297                "LogSumExp should be routed through log_sum_exp()".to_string(),
298            )),
299        }
300    }
301
302    /// Initial accumulator value for a given strategy.
303    fn initial_accumulator(&self) -> (f64, bool) {
304        match self.config.accumulation {
305            AccumulationStrategy::Sum => (0.0, false),
306            AccumulationStrategy::Mean => (0.0, true), // divide by count at end
307            AccumulationStrategy::Max => (f64::NEG_INFINITY, false),
308            AccumulationStrategy::Min => (f64::INFINITY, false),
309            AccumulationStrategy::Product => (1.0, false),
310            AccumulationStrategy::LogSumExp => (0.0, false),
311        }
312    }
313
314    /// Scalar identity for axis reduction initialisation.
315    fn initial_scalar(&self) -> f64 {
316        match self.config.accumulation {
317            AccumulationStrategy::Sum | AccumulationStrategy::Mean => 0.0,
318            AccumulationStrategy::Max => f64::NEG_INFINITY,
319            AccumulationStrategy::Min => f64::INFINITY,
320            AccumulationStrategy::Product => 1.0,
321            AccumulationStrategy::LogSumExp => 0.0,
322        }
323    }
324}
325
326// ---------------------------------------------------------------------------
327// Tests
328// ---------------------------------------------------------------------------
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a reducer with a deliberately small chunk size (4), so that most
    /// inputs below span several chunks.
    fn make_reducer(strategy: AccumulationStrategy) -> PartitionedReducer {
        let cfg = PartitionConfig::new(4).with_strategy(strategy);
        PartitionedReducer::new(cfg)
    }

    /// Element-wise approximate equality for flat reduction outputs.
    fn assert_all_close(got: &[f64], expected: &[f64]) {
        assert_eq!(
            got.len(),
            expected.len(),
            "length mismatch: got={got:?} expected={expected:?}"
        );
        for (g, e) in got.iter().zip(expected) {
            assert!((g - e).abs() < 1e-12, "got={got:?} expected={expected:?}");
        }
    }

    #[test]
    fn test_reduce_all_sum() {
        let data: Vec<f64> = (1..=10).map(|x| x as f64).collect();
        let mut r = make_reducer(AccumulationStrategy::Sum);
        let result = r.reduce_all(&data).expect("sum ok");
        assert!((result - 55.0).abs() < 1e-12, "sum={result} expected=55");
    }

    #[test]
    fn test_reduce_all_max() {
        let data = vec![3.0_f64, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0];
        let mut r = make_reducer(AccumulationStrategy::Max);
        let result = r.reduce_all(&data).expect("max ok");
        assert!((result - 9.0).abs() < 1e-12, "max={result} expected=9");
    }

    #[test]
    fn test_reduce_all_min() {
        let data = vec![3.0_f64, 1.0, 4.0, 1.0, 5.0, -2.0, 9.0, 6.0];
        let mut r = make_reducer(AccumulationStrategy::Min);
        let result = r.reduce_all(&data).expect("min ok");
        assert!((result - (-2.0)).abs() < 1e-12, "min={result} expected=-2");
    }

    #[test]
    fn test_reduce_all_mean() {
        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let mut r = make_reducer(AccumulationStrategy::Mean);
        let result = r.reduce_all(&data).expect("mean ok");
        // Mean of 1..5 = 15/5 = 3.0
        assert!((result - 3.0).abs() < 1e-10, "mean={result} expected=3.0");
    }

    #[test]
    fn test_reduce_all_mean_spans_chunks() {
        // 9 elements with chunk size 4 -> chunks of 4, 4, 1; mean is 5.
        let data: Vec<f64> = (1..=9).map(|x| x as f64).collect();
        let mut r = make_reducer(AccumulationStrategy::Mean);
        let result = r.reduce_all(&data).expect("mean ok");
        assert!((result - 5.0).abs() < 1e-12, "mean={result} expected=5.0");
    }

    #[test]
    fn test_reduce_all_product() {
        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let mut r = make_reducer(AccumulationStrategy::Product);
        let result = r.reduce_all(&data).expect("product ok");
        assert!((result - 120.0).abs() < 1e-12, "product={result} expected=120");
    }

    #[test]
    fn test_reduce_all_records_stats() {
        let data: Vec<f64> = (1..=10).map(|x| x as f64).collect();
        let mut r = make_reducer(AccumulationStrategy::Sum);
        r.reduce_all(&data).expect("sum ok");
        let stats = r.stats();
        assert_eq!(stats.chunks_processed, 3);
        assert_eq!(stats.total_elements_processed, 10);
        assert_eq!(stats.peak_chunk_size, 4);

        r.reset_stats();
        assert_eq!(r.stats().chunks_processed, 0);
        assert_eq!(r.stats().total_elements_processed, 0);
        assert_eq!(r.stats().peak_chunk_size, 0);
    }

    #[test]
    fn test_log_sum_exp_numerically_stable() {
        // log(exp(1000) + exp(1001)) = log(exp(1000)(1 + exp(1))) = 1000 + log(1 + e)
        let data = vec![1000.0_f64, 1001.0];
        let cfg = PartitionConfig::new(16).with_strategy(AccumulationStrategy::LogSumExp);
        let r = PartitionedReducer::new(cfg);
        let result = r.log_sum_exp(&data).expect("lse ok");
        let expected = 1000.0_f64 + (1.0_f64 + std::f64::consts::E).ln();
        assert!(
            (result - expected).abs() < 1e-10,
            "lse={result} expected={expected}"
        );
    }

    #[test]
    fn test_reduce_all_routes_log_sum_exp() {
        let data = vec![0.5_f64, -1.0, 2.0, 3.5, 0.0];
        let cfg = PartitionConfig::new(2).with_strategy(AccumulationStrategy::LogSumExp);
        let mut r = PartitionedReducer::new(cfg);
        let via_reduce = r.reduce_all(&data).expect("reduce_all lse ok");
        let direct = r.log_sum_exp(&data).expect("lse ok");
        assert!(
            (via_reduce - direct).abs() < 1e-12,
            "reduce_all={via_reduce} log_sum_exp={direct}"
        );
        // Small inputs: the naive formula is safe and must agree.
        let naive = data.iter().map(|x| x.exp()).sum::<f64>().ln();
        assert!((direct - naive).abs() < 1e-10, "lse={direct} naive={naive}");
    }

    #[test]
    fn test_reduce_axis_sum_2d() {
        // [[1, 2, 3],
        //  [4, 5, 6]]
        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut r = make_reducer(AccumulationStrategy::Sum);

        let (cols, col_shape) = r.reduce_axis(&data, &[2, 3], 0).expect("axis 0 ok");
        assert_eq!(col_shape, vec![3]);
        assert_all_close(&cols, &[5.0, 7.0, 9.0]);

        let (rows, row_shape) = r.reduce_axis(&data, &[2, 3], 1).expect("axis 1 ok");
        assert_eq!(row_shape, vec![2]);
        assert_all_close(&rows, &[6.0, 15.0]);
    }

    #[test]
    fn test_reduce_axis_max_3d_middle_axis() {
        // shape [2, 2, 2], values 0..8; max over axis 1 -> [[2, 3], [6, 7]].
        let data: Vec<f64> = (0..8).map(|x| x as f64).collect();
        let mut r = make_reducer(AccumulationStrategy::Max);
        let (out, out_shape) = r.reduce_axis(&data, &[2, 2, 2], 1).expect("axis 1 ok");
        assert_eq!(out_shape, vec![2, 2]);
        assert_all_close(&out, &[2.0, 3.0, 6.0, 7.0]);
    }

    #[test]
    fn test_reduce_axis_out_of_range() {
        let mut r = make_reducer(AccumulationStrategy::Sum);
        let err = r.reduce_axis(&[1.0, 2.0], &[2], 1);
        assert!(
            matches!(
                err,
                Err(PartitionedError::AxisOutOfRange { axis: 1, ndim: 1 })
            ),
            "expected AxisOutOfRange error"
        );
    }

    #[test]
    fn test_reduce_axis_shape_mismatch() {
        let mut r = make_reducer(AccumulationStrategy::Sum);
        let err = r.reduce_axis(&[1.0, 2.0, 3.0], &[2, 2], 0);
        assert!(
            matches!(err, Err(PartitionedError::ShapeMismatch { .. })),
            "expected ShapeMismatch error"
        );
    }

    #[test]
    fn test_empty_input_error() {
        let mut r = make_reducer(AccumulationStrategy::Sum);
        let err = r.reduce_all(&[]);
        assert!(
            matches!(err, Err(PartitionedError::EmptyInput)),
            "expected EmptyInput error"
        );
    }
}