// tensorlogic_infer/partitioned/config.rs

/// Accumulation strategy selectable through `PartitionConfig::accumulation`.
///
/// NOTE(review): the code that applies a strategy is not part of this file,
/// so the per-variant descriptions below are inferred from the variant names
/// and should be confirmed against the reduction implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccumulationStrategy {
    /// Presumably adds the partial results together.
    Sum,
    /// Presumably keeps the largest partial result.
    Max,
    /// Presumably keeps the smallest partial result.
    Min,
    /// Presumably averages the partial results.
    Mean,
    /// Presumably multiplies the partial results together.
    Product,
    /// Presumably computes `log(sum(exp(x)))` over the partial results.
    LogSumExp,
}
13
14#[derive(Debug, Clone)]
16pub struct PartitionConfig {
17 pub chunk_size: usize,
19 pub max_memory_bytes: Option<usize>,
21 pub accumulation: AccumulationStrategy,
23 pub parallel: bool,
26 pub epsilon: f64,
28}
29
30impl PartitionConfig {
31 pub fn new(chunk_size: usize) -> Self {
33 PartitionConfig {
34 chunk_size,
35 ..Default::default()
36 }
37 }
38
39 pub fn memory_bounded(max_bytes: usize, element_size: usize) -> Self {
44 let chunk_size = max_bytes.checked_div(element_size).unwrap_or(1).max(1);
45 PartitionConfig {
46 chunk_size,
47 max_memory_bytes: Some(max_bytes),
48 ..Default::default()
49 }
50 }
51
52 pub fn with_strategy(mut self, strategy: AccumulationStrategy) -> Self {
54 self.accumulation = strategy;
55 self
56 }
57
58 pub fn with_parallel(mut self, parallel: bool) -> Self {
60 self.parallel = parallel;
61 self
62 }
63
64 pub fn chunks_for_size(&self, total_elements: usize) -> usize {
66 if self.chunk_size == 0 {
67 return 0;
68 }
69 total_elements.div_ceil(self.chunk_size)
70 }
71}
72
73impl Default for PartitionConfig {
74 fn default() -> Self {
75 PartitionConfig {
76 chunk_size: 4096,
77 max_memory_bytes: None,
78 accumulation: AccumulationStrategy::Sum,
79 parallel: false,
80 epsilon: 1e-12,
81 }
82 }
83}
84
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` sets only the chunk size; everything else is the default.
    #[test]
    fn test_partition_config_new() {
        let cfg = PartitionConfig::new(1024);
        assert_eq!(cfg.chunk_size, 1024);
        assert!(cfg.max_memory_bytes.is_none());
        assert_eq!(cfg.accumulation, AccumulationStrategy::Sum);
        assert!(!cfg.parallel);
    }

    /// A 64-byte budget with 8-byte elements gives 8 elements per chunk.
    #[test]
    fn test_partition_config_memory_bounded() {
        let cfg = PartitionConfig::memory_bounded(64, 8);
        assert_eq!(cfg.chunk_size, 8);
        assert_eq!(cfg.max_memory_bytes, Some(64));
    }

    /// Degenerate budgets never produce a zero chunk size.
    #[test]
    fn test_partition_config_memory_bounded_degenerate() {
        // Zero element size: the division is skipped and 1 is used.
        let cfg = PartitionConfig::memory_bounded(64, 0);
        assert_eq!(cfg.chunk_size, 1);
        assert_eq!(cfg.max_memory_bytes, Some(64));

        // Budget smaller than one element: quotient 0 is raised to 1.
        let cfg = PartitionConfig::memory_bounded(4, 8);
        assert_eq!(cfg.chunk_size, 1);
        assert_eq!(cfg.max_memory_bytes, Some(4));
    }

    /// Builder methods replace exactly the field they name.
    #[test]
    fn test_partition_config_builders() {
        let cfg = PartitionConfig::new(16)
            .with_strategy(AccumulationStrategy::LogSumExp)
            .with_parallel(true);
        assert_eq!(cfg.chunk_size, 16);
        assert_eq!(cfg.accumulation, AccumulationStrategy::LogSumExp);
        assert!(cfg.parallel);
        assert!(cfg.max_memory_bytes.is_none());

        let cfg = cfg.with_parallel(false);
        assert!(!cfg.parallel);
        assert_eq!(cfg.accumulation, AccumulationStrategy::LogSumExp);
    }

    /// Chunk counts round up to cover a trailing partial chunk.
    #[test]
    fn test_chunks_for_size() {
        let cfg = PartitionConfig::new(10);
        assert_eq!(cfg.chunks_for_size(0), 0);
        assert_eq!(cfg.chunks_for_size(10), 1);
        assert_eq!(cfg.chunks_for_size(11), 2);
        assert_eq!(cfg.chunks_for_size(100), 10);
        assert_eq!(cfg.chunks_for_size(101), 11);
    }

    /// A zero chunk size reports zero chunks instead of dividing by zero.
    #[test]
    fn test_chunks_for_size_zero_chunk_size() {
        let cfg = PartitionConfig::new(0);
        assert_eq!(cfg.chunks_for_size(0), 0);
        assert_eq!(cfg.chunks_for_size(5), 0);
    }
}