1use scirs2_core::ndarray::{ArrayD, IxDyn};
12use std::collections::HashMap;
13use std::sync::{Arc, Mutex};
14
15use crate::error::{PgmError, Result};
16use crate::Factor;
17
/// Thread-safe pool of reusable `f64` buffers, bucketed by exact length.
///
/// Buffers handed back through [`FactorPool::return_array`] are zeroed and kept
/// for a later [`FactorPool::allocate`] call of the same length. At most
/// `max_pool_size` idle buffers are retained per distinct length.
#[derive(Debug)]
pub struct FactorPool {
    /// Idle buffers keyed by their length.
    pools: Mutex<HashMap<usize, Vec<Vec<f64>>>>,
    /// Usage counters. Always locked after `pools` so the lock order is fixed.
    stats: Mutex<PoolStats>,
    /// Maximum number of idle buffers retained per distinct length.
    max_pool_size: usize,
}

/// Counters describing how a [`FactorPool`] has been used.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Number of `allocate` calls served from an idle buffer.
    pub hits: usize,
    /// Number of `allocate` calls that had to create a fresh buffer.
    pub misses: usize,
    /// Number of buffers accepted back into the pool.
    pub returns: usize,
    /// Largest value `current_bytes` has reached.
    pub peak_bytes: usize,
    /// Bytes of idle buffers currently held by the pool.
    pub current_bytes: usize,
}

impl Default for FactorPool {
    /// Creates a pool that keeps up to 100 idle buffers per length.
    fn default() -> Self {
        Self::new(100)
    }
}

impl FactorPool {
    /// Creates an empty pool retaining at most `max_pool_size` idle buffers
    /// per distinct length. A value of `0` disables pooling entirely.
    pub fn new(max_pool_size: usize) -> Self {
        Self {
            pools: Mutex::new(HashMap::new()),
            stats: Mutex::new(PoolStats::default()),
            max_pool_size,
        }
    }

    /// Size in bytes of a buffer holding `len` values.
    fn byte_len(len: usize) -> usize {
        len * std::mem::size_of::<f64>()
    }

    /// Returns a zero-filled buffer of exactly `size` elements, reusing an
    /// idle buffer of that length when one is available.
    ///
    /// # Panics
    ///
    /// Panics if an internal lock is poisoned.
    pub fn allocate(&self, size: usize) -> Vec<f64> {
        let mut pools = self.pools.lock().expect("lock should not be poisoned");
        let mut stats = self.stats.lock().expect("lock should not be poisoned");

        match pools.get_mut(&size).and_then(|bucket| bucket.pop()) {
            Some(buffer) => {
                stats.hits += 1;
                stats.current_bytes -= Self::byte_len(size);
                buffer
            }
            None => {
                stats.misses += 1;
                vec![0.0; size]
            }
        }
    }

    /// Hands `array` back to the pool for later reuse.
    ///
    /// The buffer is zeroed before it is stored. If the bucket for its length
    /// already holds `max_pool_size` buffers, `array` is simply dropped.
    ///
    /// # Panics
    ///
    /// Panics if an internal lock is poisoned.
    pub fn return_array(&self, mut array: Vec<f64>) {
        // With pooling disabled nothing can ever be stored, so do not even
        // create a bucket: otherwise the map gains one empty entry for every
        // distinct length that is returned.
        if self.max_pool_size == 0 {
            return;
        }

        let size = array.len();
        let mut pools = self.pools.lock().expect("lock should not be poisoned");
        let mut stats = self.stats.lock().expect("lock should not be poisoned");

        let bucket = pools.entry(size).or_default();
        if bucket.len() >= self.max_pool_size {
            return;
        }

        // Zero on return so `allocate` can hand the buffer out unchanged.
        array.fill(0.0);
        bucket.push(array);
        stats.returns += 1;
        stats.current_bytes += Self::byte_len(size);
        stats.peak_bytes = stats.peak_bytes.max(stats.current_bytes);
    }

    /// Returns a snapshot of the usage counters.
    ///
    /// # Panics
    ///
    /// Panics if the statistics lock is poisoned.
    pub fn stats(&self) -> PoolStats {
        self.stats
            .lock()
            .expect("lock should not be poisoned")
            .clone()
    }

    /// Drops every idle buffer and resets `current_bytes` to zero.
    ///
    /// The hit, miss, return and peak counters are left untouched.
    ///
    /// # Panics
    ///
    /// Panics if an internal lock is poisoned.
    pub fn clear(&self) {
        let mut pools = self.pools.lock().expect("lock should not be poisoned");
        let mut stats = self.stats.lock().expect("lock should not be poisoned");
        pools.clear();
        stats.current_bytes = 0;
    }

    /// Fraction of `allocate` calls served from the pool, or `0.0` when
    /// nothing has been allocated yet.
    ///
    /// # Panics
    ///
    /// Panics if the statistics lock is poisoned.
    pub fn hit_rate(&self) -> f64 {
        let stats = self.stats.lock().expect("lock should not be poisoned");
        match stats.hits + stats.misses {
            0 => 0.0,
            total => stats.hits as f64 / total as f64,
        }
    }
}
125
126#[derive(Debug, Clone)]
130pub struct SparseFactor {
131 pub variables: Vec<String>,
133 pub cardinalities: Vec<usize>,
135 pub entries: Vec<(Vec<usize>, f64)>,
137 pub default_value: f64,
139}
140
141impl SparseFactor {
142 pub fn new(variables: Vec<String>, cardinalities: Vec<usize>) -> Self {
144 Self {
145 variables,
146 cardinalities,
147 entries: Vec::new(),
148 default_value: 0.0,
149 }
150 }
151
152 pub fn from_dense(factor: &Factor, threshold: f64) -> Self {
156 let shape: Vec<usize> = factor.values.shape().to_vec();
157 let mut sparse = Self::new(factor.variables.clone(), shape.clone());
158 sparse.default_value = 0.0;
159
160 let total_size: usize = shape.iter().product();
161
162 for i in 0..total_size {
163 let indices = Self::flat_to_indices(i, &shape);
164 let value = factor.values[indices.as_slice()];
165
166 if value.abs() > threshold {
167 sparse.entries.push((indices, value));
168 }
169 }
170
171 sparse
172 }
173
174 pub fn to_dense(&self) -> Result<Factor> {
176 let total_size: usize = self.cardinalities.iter().product();
177 let mut values = vec![self.default_value; total_size];
178
179 for (indices, value) in &self.entries {
180 let flat_idx = Self::indices_to_flat(indices, &self.cardinalities);
181 values[flat_idx] = *value;
182 }
183
184 let array = ArrayD::from_shape_vec(IxDyn(&self.cardinalities), values)?;
185
186 Factor::new("sparse".to_string(), self.variables.clone(), array)
187 }
188
189 pub fn get(&self, indices: &[usize]) -> f64 {
191 for (entry_indices, value) in &self.entries {
192 if entry_indices == indices {
193 return *value;
194 }
195 }
196 self.default_value
197 }
198
199 pub fn set(&mut self, indices: Vec<usize>, value: f64) {
201 for (entry_indices, entry_value) in &mut self.entries {
203 if *entry_indices == indices {
204 *entry_value = value;
205 return;
206 }
207 }
208
209 if (value - self.default_value).abs() > 1e-10 {
211 self.entries.push((indices, value));
212 }
213 }
214
215 pub fn sparsity(&self) -> f64 {
217 let total_size: usize = self.cardinalities.iter().product();
218 if total_size > 0 {
219 1.0 - (self.entries.len() as f64 / total_size as f64)
220 } else {
221 1.0
222 }
223 }
224
225 pub fn memory_savings(&self) -> f64 {
227 let dense_bytes = self.cardinalities.iter().product::<usize>() * std::mem::size_of::<f64>();
228 let sparse_bytes = self.entries.len()
229 * (self.variables.len() * std::mem::size_of::<usize>() + std::mem::size_of::<f64>());
230
231 if dense_bytes > 0 {
232 1.0 - (sparse_bytes as f64 / dense_bytes as f64)
233 } else {
234 0.0
235 }
236 }
237
238 fn flat_to_indices(flat: usize, shape: &[usize]) -> Vec<usize> {
240 let mut indices = vec![0; shape.len()];
241 let mut remaining = flat;
242
243 for i in (0..shape.len()).rev() {
244 indices[i] = remaining % shape[i];
245 remaining /= shape[i];
246 }
247
248 indices
249 }
250
251 fn indices_to_flat(indices: &[usize], shape: &[usize]) -> usize {
253 let mut flat = 0;
254 let mut stride = 1;
255
256 for i in (0..shape.len()).rev() {
257 flat += indices[i] * stride;
258 stride *= shape[i];
259 }
260
261 flat
262 }
263}
264
265#[derive(Clone)]
269pub struct LazyFactor {
270 computation: Arc<dyn Fn() -> Result<Factor> + Send + Sync>,
272 cached: Arc<Mutex<Option<Factor>>>,
274}
275
276impl std::fmt::Debug for LazyFactor {
277 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
278 f.debug_struct("LazyFactor")
279 .field(
280 "cached",
281 &self
282 .cached
283 .lock()
284 .expect("lock should not be poisoned")
285 .is_some(),
286 )
287 .finish()
288 }
289}
290
291impl LazyFactor {
292 pub fn new<F>(computation: F) -> Self
294 where
295 F: Fn() -> Result<Factor> + Send + Sync + 'static,
296 {
297 Self {
298 computation: Arc::new(computation),
299 cached: Arc::new(Mutex::new(None)),
300 }
301 }
302
303 pub fn from_factor(factor: Factor) -> Self {
305 Self {
306 computation: Arc::new(move || {
307 Err(PgmError::InvalidDistribution(
308 "Already computed".to_string(),
309 ))
310 }),
311 cached: Arc::new(Mutex::new(Some(factor))),
312 }
313 }
314
315 pub fn evaluate(&self) -> Result<Factor> {
317 let mut cached = self.cached.lock().expect("lock should not be poisoned");
318
319 if let Some(ref factor) = *cached {
320 return Ok(factor.clone());
321 }
322
323 let result = (self.computation)()?;
324 *cached = Some(result.clone());
325 Ok(result)
326 }
327
328 pub fn is_computed(&self) -> bool {
330 self.cached
331 .lock()
332 .expect("lock should not be poisoned")
333 .is_some()
334 }
335
336 pub fn clear_cache(&self) {
338 let mut cached = self.cached.lock().expect("lock should not be poisoned");
339 *cached = None;
340 }
341
342 pub fn lazy_product(a: LazyFactor, b: LazyFactor) -> LazyFactor {
344 LazyFactor::new(move || {
345 let factor_a = a.evaluate()?;
346 let factor_b = b.evaluate()?;
347 factor_a.product(&factor_b)
348 })
349 }
350
351 pub fn lazy_marginalize(factor: LazyFactor, var: String) -> LazyFactor {
353 LazyFactor::new(move || {
354 let f = factor.evaluate()?;
355 f.marginalize_out(&var)
356 })
357 }
358}
359
360pub struct StreamingFactorGraph {
364 variables: HashMap<String, VariableInfo>,
366 factor_generators: Vec<Box<dyn Fn() -> Result<Factor> + Send + Sync>>,
368 #[allow(dead_code)]
370 pool: Arc<FactorPool>,
371}
372
373#[derive(Debug, Clone)]
375#[allow(dead_code)]
376struct VariableInfo {
377 domain: String,
378 cardinality: usize,
379}
380
381impl StreamingFactorGraph {
382 pub fn new() -> Self {
384 Self {
385 variables: HashMap::new(),
386 factor_generators: Vec::new(),
387 pool: Arc::new(FactorPool::default()),
388 }
389 }
390
391 pub fn with_pool(pool: Arc<FactorPool>) -> Self {
393 Self {
394 variables: HashMap::new(),
395 factor_generators: Vec::new(),
396 pool,
397 }
398 }
399
400 pub fn add_variable(&mut self, name: String, domain: String, cardinality: usize) {
402 self.variables.insert(
403 name,
404 VariableInfo {
405 domain,
406 cardinality,
407 },
408 );
409 }
410
411 pub fn add_factor<F>(&mut self, generator: F)
413 where
414 F: Fn() -> Result<Factor> + Send + Sync + 'static,
415 {
416 self.factor_generators.push(Box::new(generator));
417 }
418
419 pub fn stream_factors(&self) -> impl Iterator<Item = Result<Factor>> + '_ {
421 self.factor_generators.iter().map(|gen| gen())
422 }
423
424 pub fn streaming_product(&self) -> Result<Factor> {
428 let mut result: Option<Factor> = None;
429
430 for gen in &self.factor_generators {
431 let factor = gen()?;
432
433 result = match result {
434 None => Some(factor),
435 Some(r) => Some(r.product(&factor)?),
436 };
437 }
438
439 result.ok_or_else(|| PgmError::InvalidDistribution("No factors in graph".to_string()))
440 }
441
442 pub fn num_variables(&self) -> usize {
444 self.variables.len()
445 }
446
447 pub fn num_factors(&self) -> usize {
449 self.factor_generators.len()
450 }
451}
452
453impl Default for StreamingFactorGraph {
454 fn default() -> Self {
455 Self::new()
456 }
457}
458
459#[derive(Debug, Clone)]
463pub struct CompressedFactor {
464 pub variables: Vec<String>,
466 pub cardinalities: Vec<usize>,
468 quantized: Vec<u16>,
470 min_value: f64,
472 scale: f64,
474}
475
476impl CompressedFactor {
477 pub fn from_factor(factor: &Factor) -> Self {
479 let values: Vec<f64> = factor.values.iter().copied().collect();
480 let cardinalities: Vec<usize> = factor.values.shape().to_vec();
481
482 let min_value = values.iter().copied().fold(f64::INFINITY, f64::min);
483 let max_value = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
484
485 let scale = if max_value > min_value {
486 (max_value - min_value) / 65535.0
487 } else {
488 1.0
489 };
490
491 let quantized: Vec<u16> = values
492 .iter()
493 .map(|&v| ((v - min_value) / scale).round() as u16)
494 .collect();
495
496 Self {
497 variables: factor.variables.clone(),
498 cardinalities,
499 quantized,
500 min_value,
501 scale,
502 }
503 }
504
505 pub fn to_factor(&self) -> Result<Factor> {
507 let values: Vec<f64> = self
508 .quantized
509 .iter()
510 .map(|&q| self.min_value + (q as f64) * self.scale)
511 .collect();
512
513 let array = ArrayD::from_shape_vec(IxDyn(&self.cardinalities), values)?;
514
515 Factor::new("compressed".to_string(), self.variables.clone(), array)
516 }
517
518 pub fn memory_size(&self) -> usize {
520 self.quantized.len() * std::mem::size_of::<u16>()
521 + self.variables.len() * std::mem::size_of::<String>()
522 + self.cardinalities.len() * std::mem::size_of::<usize>()
523 + 2 * std::mem::size_of::<f64>()
524 }
525
526 pub fn compression_ratio(&self) -> f64 {
528 let original = self.quantized.len() * std::mem::size_of::<f64>();
529 let compressed = self.quantized.len() * std::mem::size_of::<u16>();
530
531 if compressed > 0 {
532 original as f64 / compressed as f64
533 } else {
534 1.0
535 }
536 }
537}
538
539#[derive(Debug, Clone)]
543pub struct BlockSparseFactor {
544 pub variables: Vec<String>,
546 pub cardinalities: Vec<usize>,
548 pub block_size: usize,
550 blocks: HashMap<Vec<usize>, Vec<f64>>,
552 default_value: f64,
554}
555
556impl BlockSparseFactor {
557 pub fn new(variables: Vec<String>, cardinalities: Vec<usize>, block_size: usize) -> Self {
559 Self {
560 variables,
561 cardinalities,
562 block_size,
563 blocks: HashMap::new(),
564 default_value: 0.0,
565 }
566 }
567
568 pub fn from_factor(factor: &Factor, block_size: usize, threshold: f64) -> Self {
570 let shape: Vec<usize> = factor.values.shape().to_vec();
571 let mut sparse = Self::new(factor.variables.clone(), shape.clone(), block_size);
572 sparse.default_value = 0.0;
573 let block_dims: Vec<usize> = shape.iter().map(|&d| d.div_ceil(block_size)).collect();
574
575 let total_blocks: usize = block_dims.iter().product();
577 for block_flat in 0..total_blocks {
578 let block_indices = SparseFactor::flat_to_indices(block_flat, &block_dims);
579
580 let block_total = block_size.pow(shape.len() as u32);
582 let mut block_values = Vec::with_capacity(block_total);
583 let mut has_nonzero = false;
584
585 for local_flat in 0..block_total {
586 let local_indices =
587 SparseFactor::flat_to_indices(local_flat, &vec![block_size; shape.len()]);
588
589 let global_indices: Vec<usize> = block_indices
591 .iter()
592 .zip(local_indices.iter())
593 .zip(shape.iter())
594 .map(|((&bi, &li), &s)| (bi * block_size + li).min(s - 1))
595 .collect();
596
597 let value = factor.values[global_indices.as_slice()];
598 block_values.push(value);
599
600 if value.abs() > threshold {
601 has_nonzero = true;
602 }
603 }
604
605 if has_nonzero {
606 sparse.blocks.insert(block_indices, block_values);
607 }
608 }
609
610 sparse
611 }
612
613 pub fn num_blocks(&self) -> usize {
615 self.blocks.len()
616 }
617
618 pub fn block_sparsity(&self) -> f64 {
620 let block_dims: Vec<usize> = self
621 .cardinalities
622 .iter()
623 .map(|&d| d.div_ceil(self.block_size))
624 .collect();
625 let total_blocks: usize = block_dims.iter().product();
626
627 if total_blocks > 0 {
628 1.0 - (self.blocks.len() as f64 / total_blocks as f64)
629 } else {
630 1.0
631 }
632 }
633}
634
/// Estimates the memory needed for a factor graph of the given size.
///
/// * Factors: `num_factors * avg_cardinality^avg_scope_size` `f64` entries.
/// * Messages: two per edge, where edges are `num_factors * avg_scope_size`,
///   each holding `avg_cardinality` `f64` entries.
/// * Marginals: `num_variables * avg_cardinality` `f64` entries.
///
/// Factor tables grow exponentially with the scope size, so every step uses
/// saturating arithmetic: an estimate too large for `usize` is reported as
/// `usize::MAX` instead of overflowing.
pub fn estimate_memory_usage(
    num_variables: usize,
    avg_cardinality: usize,
    num_factors: usize,
    avg_scope_size: usize,
) -> MemoryEstimate {
    let bytes_per_entry = std::mem::size_of::<f64>();

    // Clamp rather than truncate a scope size wider than `u32`.
    let exponent = avg_scope_size.min(u32::MAX as usize) as u32;
    let avg_factor_size = avg_cardinality.saturating_pow(exponent);
    let factor_bytes = num_factors
        .saturating_mul(avg_factor_size)
        .saturating_mul(bytes_per_entry);

    let edges = num_factors.saturating_mul(avg_scope_size);
    let message_bytes = edges
        .saturating_mul(2)
        .saturating_mul(avg_cardinality)
        .saturating_mul(bytes_per_entry);

    let marginal_bytes = num_variables
        .saturating_mul(avg_cardinality)
        .saturating_mul(bytes_per_entry);

    let total_bytes = factor_bytes
        .saturating_add(message_bytes)
        .saturating_add(marginal_bytes);

    MemoryEstimate {
        factor_bytes,
        message_bytes,
        marginal_bytes,
        total_bytes,
    }
}

/// Byte counts produced by [`estimate_memory_usage`].
#[derive(Debug, Clone)]
pub struct MemoryEstimate {
    /// Bytes taken by the factor tables.
    pub factor_bytes: usize,
    /// Bytes taken by the messages.
    pub message_bytes: usize,
    /// Bytes taken by the variable marginals.
    pub marginal_bytes: usize,
    /// Sum of the three components.
    pub total_bytes: usize,
}

impl std::fmt::Display for MemoryEstimate {
    /// Formats every component in mebibytes with two decimal places.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const BYTES_PER_MB: f64 = 1_048_576.0;
        let to_mb = |bytes: usize| bytes as f64 / BYTES_PER_MB;

        write!(
            f,
            "Memory Estimate: {:.2} MB total (factors: {:.2} MB, messages: {:.2} MB, marginals: {:.2} MB)",
            to_mb(self.total_bytes),
            to_mb(self.factor_bytes),
            to_mb(self.message_bytes),
            to_mb(self.marginal_bytes)
        )
    }
}
687
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    /// Builds a one-dimensional factor named "test" over the variable "x".
    fn factor_over_x(values: Vec<f64>) -> Factor {
        Factor::new(
            "test".to_string(),
            vec!["x".to_string()],
            Array::from_vec(values).into_dyn(),
        )
        .expect("unwrap")
    }

    /// Builds a two-dimensional factor named "test" over "x" and "y" in which
    /// every cell holds `fill`.
    fn constant_factor_over_xy(shape: &[usize], fill: f64) -> Factor {
        Factor::new(
            "test".to_string(),
            vec!["x".to_string(), "y".to_string()],
            ArrayD::from_elem(IxDyn(shape), fill),
        )
        .expect("unwrap")
    }

    #[test]
    fn test_factor_pool_allocation() {
        let pool = FactorPool::new(10);

        let first = pool.allocate(100);
        assert_eq!(first.len(), 100);

        pool.return_array(first);
        assert_eq!(pool.stats().returns, 1);

        // The second request of the same length must come from the pool.
        let second = pool.allocate(100);
        assert_eq!(second.len(), 100);
        assert_eq!(pool.stats().hits, 1);
    }

    #[test]
    fn test_factor_pool_hit_rate() {
        let pool = FactorPool::new(10);

        // One miss followed by one hit.
        let buffer = pool.allocate(50);
        pool.return_array(buffer);
        drop(pool.allocate(50));

        assert!(pool.hit_rate() > 0.4);
    }

    #[test]
    fn test_sparse_factor_creation() {
        let mut sparse = SparseFactor::new(vec!["x".to_string()], vec![4]);

        sparse.set(vec![0], 1.0);
        sparse.set(vec![2], 0.5);

        assert_abs_diff_eq!(sparse.get(&[0]), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(&[1]), 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(&[2]), 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_sparse_factor_from_dense() {
        let factor = factor_over_x(vec![0.0, 1.0, 0.0, 0.5]);

        let sparse = SparseFactor::from_dense(&factor, 0.1);
        assert_eq!(sparse.entries.len(), 2);

        let dense = sparse.to_dense().expect("unwrap");
        assert_abs_diff_eq!(dense.values[[1]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(dense.values[[3]], 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_sparse_factor_sparsity() {
        let mut sparse = SparseFactor::new(vec!["x".to_string()], vec![100]);
        sparse.set(vec![50], 1.0);

        assert!(sparse.sparsity() > 0.98);
    }

    #[test]
    fn test_lazy_factor_deferred() {
        let calls = Arc::new(Mutex::new(0));
        let calls_in_closure = Arc::clone(&calls);
        let call_count = || *calls.lock().expect("unwrap");

        let lazy = LazyFactor::new(move || {
            *calls_in_closure.lock().expect("unwrap") += 1;
            Factor::new(
                "test".to_string(),
                vec!["x".to_string()],
                Array::from_vec(vec![0.5, 0.5]).into_dyn(),
            )
        });

        // Nothing runs until the first evaluation.
        assert!(!lazy.is_computed());
        assert_eq!(call_count(), 0);

        let _ = lazy.evaluate().expect("unwrap");
        assert!(lazy.is_computed());
        assert_eq!(call_count(), 1);

        // The second evaluation is served from the cache.
        let _ = lazy.evaluate().expect("unwrap");
        assert_eq!(call_count(), 1);
    }

    #[test]
    fn test_lazy_factor_from_factor() {
        let lazy = LazyFactor::from_factor(factor_over_x(vec![0.3, 0.7]));
        assert!(lazy.is_computed());

        let result = lazy.evaluate().expect("unwrap");
        assert_abs_diff_eq!(result.values[[0]], 0.3, epsilon = 1e-10);
    }

    #[test]
    fn test_compressed_factor() {
        let factor = factor_over_x(vec![0.1, 0.2, 0.3, 0.4]);

        let compressed = CompressedFactor::from_factor(&factor);
        let decompressed = compressed.to_factor().expect("unwrap");

        for i in 0..4 {
            assert_abs_diff_eq!(factor.values[[i]], decompressed.values[[i]], epsilon = 0.01);
        }
    }

    #[test]
    fn test_compressed_factor_ratio() {
        let factor = constant_factor_over_xy(&[10, 10], 0.5);

        let ratio = CompressedFactor::from_factor(&factor).compression_ratio();

        assert!(ratio > 3.5);
    }

    #[test]
    fn test_streaming_factor_graph() {
        let mut graph = StreamingFactorGraph::new();
        for name in ["x", "y"] {
            graph.add_variable(name.to_string(), "Binary".to_string(), 2);
        }

        graph.add_factor(|| {
            Factor::new(
                "factor_x".to_string(),
                vec!["x".to_string()],
                Array::from_vec(vec![0.3, 0.7]).into_dyn(),
            )
        });

        graph.add_factor(|| {
            Factor::new(
                "factor_y".to_string(),
                vec!["y".to_string()],
                Array::from_vec(vec![0.4, 0.6]).into_dyn(),
            )
        });

        assert_eq!(graph.num_variables(), 2);
        assert_eq!(graph.num_factors(), 2);
    }

    #[test]
    fn test_memory_estimate() {
        let estimate = estimate_memory_usage(10, 3, 20, 3);

        assert!(estimate.total_bytes > 0);
        assert!(estimate.factor_bytes > 0);
        assert!(estimate.message_bytes > 0);
    }

    #[test]
    fn test_block_sparse_factor() {
        // An all-zero table must not keep any block.
        let factor = constant_factor_over_xy(&[8, 8], 0.0);

        let block_sparse = BlockSparseFactor::from_factor(&factor, 4, 0.001);

        assert!(block_sparse.block_sparsity() > 0.99);
    }
}