Skip to main content

tensorlogic_infer/partitioned/
config.rs

1//! Configuration for partitioned (memory-efficient) tensor reductions.
2
/// Strategy for accumulating partial results across chunks.
///
/// The default strategy is [`AccumulationStrategy::Sum`], the same strategy
/// `PartitionConfig::default()` selects.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum AccumulationStrategy {
    /// Combine partial results by addition.
    #[default]
    Sum,
    /// Keep the largest partial result.
    Max,
    /// Keep the smallest partial result.
    Min,
    /// Average of the partial results.
    ///
    /// NOTE(review): how partial results from unequal chunks are weighted is
    /// decided by the code that consumes this enum (not in this file) — confirm there.
    Mean,
    /// Combine partial results by multiplication.
    Product,
    /// Log-sum-exp reduction, `log(Σ exp(x))`. The `epsilon` field of
    /// `PartitionConfig` is documented as a numerical-stability guard for it.
    LogSumExp,
}
13
14/// Configuration for a partitioned reduction.
15#[derive(Debug, Clone)]
16pub struct PartitionConfig {
17    /// Number of elements per partition chunk.
18    pub chunk_size: usize,
19    /// Optional memory budget in bytes.
20    pub max_memory_bytes: Option<usize>,
21    /// How to accumulate across chunks.
22    pub accumulation: AccumulationStrategy,
23    /// Whether to run chunks in parallel (placeholder flag; actual parallelism
24    /// requires a rayon dependency that is not present in this crate).
25    pub parallel: bool,
26    /// Numerical stability epsilon (used in log-sum-exp and division guards).
27    pub epsilon: f64,
28}
29
30impl PartitionConfig {
31    /// Create a new config with the given chunk size and default settings.
32    pub fn new(chunk_size: usize) -> Self {
33        PartitionConfig {
34            chunk_size,
35            ..Default::default()
36        }
37    }
38
39    /// Derive chunk size from a memory budget and the element size in bytes.
40    ///
41    /// The computed chunk size is `max_bytes / element_size`, clamped to at
42    /// least 1.
43    pub fn memory_bounded(max_bytes: usize, element_size: usize) -> Self {
44        let chunk_size = max_bytes.checked_div(element_size).unwrap_or(1).max(1);
45        PartitionConfig {
46            chunk_size,
47            max_memory_bytes: Some(max_bytes),
48            ..Default::default()
49        }
50    }
51
52    /// Set the accumulation strategy.
53    pub fn with_strategy(mut self, strategy: AccumulationStrategy) -> Self {
54        self.accumulation = strategy;
55        self
56    }
57
58    /// Enable or disable parallel chunk processing.
59    pub fn with_parallel(mut self, parallel: bool) -> Self {
60        self.parallel = parallel;
61        self
62    }
63
64    /// Return the number of chunks needed to process `total_elements` elements.
65    pub fn chunks_for_size(&self, total_elements: usize) -> usize {
66        if self.chunk_size == 0 {
67            return 0;
68        }
69        total_elements.div_ceil(self.chunk_size)
70    }
71}
72
73impl Default for PartitionConfig {
74    fn default() -> Self {
75        PartitionConfig {
76            chunk_size: 4096,
77            max_memory_bytes: None,
78            accumulation: AccumulationStrategy::Sum,
79            parallel: false,
80            epsilon: 1e-12,
81        }
82    }
83}
84
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_partition_config_new() {
        let cfg = PartitionConfig::new(1024);
        assert_eq!(cfg.chunk_size, 1024);
        assert!(cfg.max_memory_bytes.is_none());
        assert_eq!(cfg.accumulation, AccumulationStrategy::Sum);
        assert!(!cfg.parallel);
    }

    #[test]
    fn test_partition_config_default() {
        let cfg = PartitionConfig::default();
        assert_eq!(cfg.chunk_size, 4096);
        assert!(cfg.max_memory_bytes.is_none());
        assert_eq!(cfg.accumulation, AccumulationStrategy::Sum);
        assert!(!cfg.parallel);
        assert_eq!(cfg.epsilon, 1e-12);
    }

    #[test]
    fn test_partition_config_memory_bounded() {
        // 64 bytes / 8 bytes per f64 = 8 elements per chunk
        let cfg = PartitionConfig::memory_bounded(64, 8);
        assert_eq!(cfg.chunk_size, 8);
        assert_eq!(cfg.max_memory_bytes, Some(64));
    }

    #[test]
    fn test_memory_bounded_rounds_down_and_clamps() {
        // Integer division rounds down: 100 / 8 = 12.
        assert_eq!(PartitionConfig::memory_bounded(100, 8).chunk_size, 12);
        // A budget smaller than one element still yields one element per chunk.
        assert_eq!(PartitionConfig::memory_bounded(3, 8).chunk_size, 1);
        assert_eq!(PartitionConfig::memory_bounded(0, 8).chunk_size, 1);
        // A zero element size would divide by zero; the chunk size falls back to 1.
        let cfg = PartitionConfig::memory_bounded(64, 0);
        assert_eq!(cfg.chunk_size, 1);
        assert_eq!(cfg.max_memory_bytes, Some(64));
    }

    #[test]
    fn test_builders() {
        let cfg = PartitionConfig::new(16)
            .with_strategy(AccumulationStrategy::LogSumExp)
            .with_parallel(true);
        assert_eq!(cfg.chunk_size, 16);
        assert_eq!(cfg.accumulation, AccumulationStrategy::LogSumExp);
        assert!(cfg.parallel);
    }

    #[test]
    fn test_chunks_for_size() {
        let cfg = PartitionConfig::new(10);
        assert_eq!(cfg.chunks_for_size(0), 0);
        assert_eq!(cfg.chunks_for_size(10), 1);
        assert_eq!(cfg.chunks_for_size(11), 2);
        assert_eq!(cfg.chunks_for_size(100), 10);
        assert_eq!(cfg.chunks_for_size(101), 11);
    }

    #[test]
    fn test_chunks_for_size_zero_chunk_size() {
        // A zero chunk size is guarded (no division by zero) and reports no chunks.
        let cfg = PartitionConfig::new(0);
        assert_eq!(cfg.chunks_for_size(0), 0);
        assert_eq!(cfg.chunks_for_size(100), 0);
    }
}