1use scirs2_core::ndarray::{ArrayD, IxDyn};
12use std::collections::HashMap;
13use std::sync::{Arc, Mutex};
14
15use crate::error::{PgmError, Result};
16use crate::Factor;
17
/// Thread-safe pool of reusable `f64` buffers, bucketed by exact length.
///
/// Buffers handed back through [`FactorPool::return_array`] are zero-filled and
/// kept for later [`FactorPool::allocate`] calls of the same length, up to
/// `max_pool_size` idle buffers per length.
#[derive(Debug)]
pub struct FactorPool {
    /// Idle buffers keyed by their length.
    pools: Mutex<HashMap<usize, Vec<Vec<f64>>>>,
    /// Usage counters, see [`PoolStats`].
    stats: Mutex<PoolStats>,
    /// Maximum number of idle buffers retained per length bucket.
    max_pool_size: usize,
}

/// Usage counters reported by [`FactorPool::stats`].
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Allocations served from an idle pooled buffer.
    pub hits: usize,
    /// Allocations that required a fresh buffer.
    pub misses: usize,
    /// Buffers accepted back into the pool.
    pub returns: usize,
    /// Largest value `current_bytes` has reached so far.
    pub peak_bytes: usize,
    /// Bytes currently held by idle pooled buffers.
    pub current_bytes: usize,
}

impl Default for FactorPool {
    /// Creates a pool that keeps at most 100 idle buffers per length.
    fn default() -> Self {
        Self::new(100)
    }
}

impl FactorPool {
    /// Creates an empty pool retaining at most `max_pool_size` idle buffers
    /// per distinct buffer length.
    pub fn new(max_pool_size: usize) -> Self {
        Self {
            pools: Mutex::new(HashMap::new()),
            stats: Mutex::new(PoolStats::default()),
            max_pool_size,
        }
    }

    /// Size in bytes of a buffer holding `len` values.
    fn buffer_bytes(len: usize) -> usize {
        len * std::mem::size_of::<f64>()
    }

    /// Locks the bucket map, recovering the guard if the mutex was poisoned.
    ///
    /// The map only holds plain buffers, so it remains usable after a panic
    /// in another thread; refusing to hand it out would disable the pool.
    fn lock_pools(&self) -> std::sync::MutexGuard<'_, HashMap<usize, Vec<Vec<f64>>>> {
        self.pools
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Locks the counters, recovering the guard if the mutex was poisoned.
    fn lock_stats(&self) -> std::sync::MutexGuard<'_, PoolStats> {
        self.stats
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns a zero-filled buffer of exactly `size` elements.
    ///
    /// An idle pooled buffer of that length is reused when available (counted
    /// as a hit); otherwise a new buffer is allocated (counted as a miss).
    pub fn allocate(&self, size: usize) -> Vec<f64> {
        {
            // Lock order is always `pools` then `stats`.
            let mut pools = self.lock_pools();
            let mut stats = self.lock_stats();

            if let Some(buffer) = pools.get_mut(&size).and_then(|bucket| bucket.pop()) {
                stats.hits += 1;
                stats.current_bytes = stats
                    .current_bytes
                    .saturating_sub(Self::buffer_bytes(size));
                return buffer;
            }

            stats.misses += 1;
        }

        // Both locks are released before the fresh allocation so other
        // threads are not blocked while memory is being obtained.
        vec![0.0; size]
    }

    /// Hands `array` back to the pool for reuse.
    ///
    /// The buffer is zeroed and stored under its length. If that bucket
    /// already holds `max_pool_size` buffers the array is simply dropped and
    /// no counter changes.
    pub fn return_array(&self, mut array: Vec<f64>) {
        let size = array.len();
        let mut pools = self.lock_pools();
        let mut stats = self.lock_stats();

        let bucket = pools.entry(size).or_default();
        if bucket.len() >= self.max_pool_size {
            return;
        }

        // Zero on return so `allocate` can hand the buffer out as-is.
        array.fill(0.0);
        bucket.push(array);

        stats.returns += 1;
        stats.current_bytes += Self::buffer_bytes(size);
        stats.peak_bytes = stats.peak_bytes.max(stats.current_bytes);
    }

    /// Returns a snapshot of the current counters.
    pub fn stats(&self) -> PoolStats {
        self.lock_stats().clone()
    }

    /// Drops every idle buffer and resets `current_bytes` to zero.
    ///
    /// The hit/miss/return counters and `peak_bytes` are left untouched.
    pub fn clear(&self) {
        let mut pools = self.lock_pools();
        let mut stats = self.lock_stats();
        pools.clear();
        stats.current_bytes = 0;
    }

    /// Fraction of allocations served from the pool, or `0.0` when nothing
    /// has been allocated yet.
    pub fn hit_rate(&self) -> f64 {
        let stats = self.lock_stats();
        match stats.hits + stats.misses {
            0 => 0.0,
            total => stats.hits as f64 / total as f64,
        }
    }
}
122
123#[derive(Debug, Clone)]
127pub struct SparseFactor {
128 pub variables: Vec<String>,
130 pub cardinalities: Vec<usize>,
132 pub entries: Vec<(Vec<usize>, f64)>,
134 pub default_value: f64,
136}
137
138impl SparseFactor {
139 pub fn new(variables: Vec<String>, cardinalities: Vec<usize>) -> Self {
141 Self {
142 variables,
143 cardinalities,
144 entries: Vec::new(),
145 default_value: 0.0,
146 }
147 }
148
149 pub fn from_dense(factor: &Factor, threshold: f64) -> Self {
153 let shape: Vec<usize> = factor.values.shape().to_vec();
154 let mut sparse = Self::new(factor.variables.clone(), shape.clone());
155 sparse.default_value = 0.0;
156
157 let total_size: usize = shape.iter().product();
158
159 for i in 0..total_size {
160 let indices = Self::flat_to_indices(i, &shape);
161 let value = factor.values[indices.as_slice()];
162
163 if value.abs() > threshold {
164 sparse.entries.push((indices, value));
165 }
166 }
167
168 sparse
169 }
170
171 pub fn to_dense(&self) -> Result<Factor> {
173 let total_size: usize = self.cardinalities.iter().product();
174 let mut values = vec![self.default_value; total_size];
175
176 for (indices, value) in &self.entries {
177 let flat_idx = Self::indices_to_flat(indices, &self.cardinalities);
178 values[flat_idx] = *value;
179 }
180
181 let array = ArrayD::from_shape_vec(IxDyn(&self.cardinalities), values)?;
182
183 Factor::new("sparse".to_string(), self.variables.clone(), array)
184 }
185
186 pub fn get(&self, indices: &[usize]) -> f64 {
188 for (entry_indices, value) in &self.entries {
189 if entry_indices == indices {
190 return *value;
191 }
192 }
193 self.default_value
194 }
195
196 pub fn set(&mut self, indices: Vec<usize>, value: f64) {
198 for (entry_indices, entry_value) in &mut self.entries {
200 if *entry_indices == indices {
201 *entry_value = value;
202 return;
203 }
204 }
205
206 if (value - self.default_value).abs() > 1e-10 {
208 self.entries.push((indices, value));
209 }
210 }
211
212 pub fn sparsity(&self) -> f64 {
214 let total_size: usize = self.cardinalities.iter().product();
215 if total_size > 0 {
216 1.0 - (self.entries.len() as f64 / total_size as f64)
217 } else {
218 1.0
219 }
220 }
221
222 pub fn memory_savings(&self) -> f64 {
224 let dense_bytes = self.cardinalities.iter().product::<usize>() * std::mem::size_of::<f64>();
225 let sparse_bytes = self.entries.len()
226 * (self.variables.len() * std::mem::size_of::<usize>() + std::mem::size_of::<f64>());
227
228 if dense_bytes > 0 {
229 1.0 - (sparse_bytes as f64 / dense_bytes as f64)
230 } else {
231 0.0
232 }
233 }
234
235 fn flat_to_indices(flat: usize, shape: &[usize]) -> Vec<usize> {
237 let mut indices = vec![0; shape.len()];
238 let mut remaining = flat;
239
240 for i in (0..shape.len()).rev() {
241 indices[i] = remaining % shape[i];
242 remaining /= shape[i];
243 }
244
245 indices
246 }
247
248 fn indices_to_flat(indices: &[usize], shape: &[usize]) -> usize {
250 let mut flat = 0;
251 let mut stride = 1;
252
253 for i in (0..shape.len()).rev() {
254 flat += indices[i] * stride;
255 stride *= shape[i];
256 }
257
258 flat
259 }
260}
261
262#[derive(Clone)]
266pub struct LazyFactor {
267 computation: Arc<dyn Fn() -> Result<Factor> + Send + Sync>,
269 cached: Arc<Mutex<Option<Factor>>>,
271}
272
273impl std::fmt::Debug for LazyFactor {
274 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275 f.debug_struct("LazyFactor")
276 .field("cached", &self.cached.lock().unwrap().is_some())
277 .finish()
278 }
279}
280
281impl LazyFactor {
282 pub fn new<F>(computation: F) -> Self
284 where
285 F: Fn() -> Result<Factor> + Send + Sync + 'static,
286 {
287 Self {
288 computation: Arc::new(computation),
289 cached: Arc::new(Mutex::new(None)),
290 }
291 }
292
293 pub fn from_factor(factor: Factor) -> Self {
295 Self {
296 computation: Arc::new(move || {
297 Err(PgmError::InvalidDistribution(
298 "Already computed".to_string(),
299 ))
300 }),
301 cached: Arc::new(Mutex::new(Some(factor))),
302 }
303 }
304
305 pub fn evaluate(&self) -> Result<Factor> {
307 let mut cached = self.cached.lock().unwrap();
308
309 if let Some(ref factor) = *cached {
310 return Ok(factor.clone());
311 }
312
313 let result = (self.computation)()?;
314 *cached = Some(result.clone());
315 Ok(result)
316 }
317
318 pub fn is_computed(&self) -> bool {
320 self.cached.lock().unwrap().is_some()
321 }
322
323 pub fn clear_cache(&self) {
325 let mut cached = self.cached.lock().unwrap();
326 *cached = None;
327 }
328
329 pub fn lazy_product(a: LazyFactor, b: LazyFactor) -> LazyFactor {
331 LazyFactor::new(move || {
332 let factor_a = a.evaluate()?;
333 let factor_b = b.evaluate()?;
334 factor_a.product(&factor_b)
335 })
336 }
337
338 pub fn lazy_marginalize(factor: LazyFactor, var: String) -> LazyFactor {
340 LazyFactor::new(move || {
341 let f = factor.evaluate()?;
342 f.marginalize_out(&var)
343 })
344 }
345}
346
347pub struct StreamingFactorGraph {
351 variables: HashMap<String, VariableInfo>,
353 factor_generators: Vec<Box<dyn Fn() -> Result<Factor> + Send + Sync>>,
355 #[allow(dead_code)]
357 pool: Arc<FactorPool>,
358}
359
360#[derive(Debug, Clone)]
362#[allow(dead_code)]
363struct VariableInfo {
364 domain: String,
365 cardinality: usize,
366}
367
368impl StreamingFactorGraph {
369 pub fn new() -> Self {
371 Self {
372 variables: HashMap::new(),
373 factor_generators: Vec::new(),
374 pool: Arc::new(FactorPool::default()),
375 }
376 }
377
378 pub fn with_pool(pool: Arc<FactorPool>) -> Self {
380 Self {
381 variables: HashMap::new(),
382 factor_generators: Vec::new(),
383 pool,
384 }
385 }
386
387 pub fn add_variable(&mut self, name: String, domain: String, cardinality: usize) {
389 self.variables.insert(
390 name,
391 VariableInfo {
392 domain,
393 cardinality,
394 },
395 );
396 }
397
398 pub fn add_factor<F>(&mut self, generator: F)
400 where
401 F: Fn() -> Result<Factor> + Send + Sync + 'static,
402 {
403 self.factor_generators.push(Box::new(generator));
404 }
405
406 pub fn stream_factors(&self) -> impl Iterator<Item = Result<Factor>> + '_ {
408 self.factor_generators.iter().map(|gen| gen())
409 }
410
411 pub fn streaming_product(&self) -> Result<Factor> {
415 let mut result: Option<Factor> = None;
416
417 for gen in &self.factor_generators {
418 let factor = gen()?;
419
420 result = match result {
421 None => Some(factor),
422 Some(r) => Some(r.product(&factor)?),
423 };
424 }
425
426 result.ok_or_else(|| PgmError::InvalidDistribution("No factors in graph".to_string()))
427 }
428
429 pub fn num_variables(&self) -> usize {
431 self.variables.len()
432 }
433
434 pub fn num_factors(&self) -> usize {
436 self.factor_generators.len()
437 }
438}
439
440impl Default for StreamingFactorGraph {
441 fn default() -> Self {
442 Self::new()
443 }
444}
445
446#[derive(Debug, Clone)]
450pub struct CompressedFactor {
451 pub variables: Vec<String>,
453 pub cardinalities: Vec<usize>,
455 quantized: Vec<u16>,
457 min_value: f64,
459 scale: f64,
461}
462
463impl CompressedFactor {
464 pub fn from_factor(factor: &Factor) -> Self {
466 let values: Vec<f64> = factor.values.iter().copied().collect();
467 let cardinalities: Vec<usize> = factor.values.shape().to_vec();
468
469 let min_value = values.iter().copied().fold(f64::INFINITY, f64::min);
470 let max_value = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
471
472 let scale = if max_value > min_value {
473 (max_value - min_value) / 65535.0
474 } else {
475 1.0
476 };
477
478 let quantized: Vec<u16> = values
479 .iter()
480 .map(|&v| ((v - min_value) / scale).round() as u16)
481 .collect();
482
483 Self {
484 variables: factor.variables.clone(),
485 cardinalities,
486 quantized,
487 min_value,
488 scale,
489 }
490 }
491
492 pub fn to_factor(&self) -> Result<Factor> {
494 let values: Vec<f64> = self
495 .quantized
496 .iter()
497 .map(|&q| self.min_value + (q as f64) * self.scale)
498 .collect();
499
500 let array = ArrayD::from_shape_vec(IxDyn(&self.cardinalities), values)?;
501
502 Factor::new("compressed".to_string(), self.variables.clone(), array)
503 }
504
505 pub fn memory_size(&self) -> usize {
507 self.quantized.len() * std::mem::size_of::<u16>()
508 + self.variables.len() * std::mem::size_of::<String>()
509 + self.cardinalities.len() * std::mem::size_of::<usize>()
510 + 2 * std::mem::size_of::<f64>()
511 }
512
513 pub fn compression_ratio(&self) -> f64 {
515 let original = self.quantized.len() * std::mem::size_of::<f64>();
516 let compressed = self.quantized.len() * std::mem::size_of::<u16>();
517
518 if compressed > 0 {
519 original as f64 / compressed as f64
520 } else {
521 1.0
522 }
523 }
524}
525
526#[derive(Debug, Clone)]
530pub struct BlockSparseFactor {
531 pub variables: Vec<String>,
533 pub cardinalities: Vec<usize>,
535 pub block_size: usize,
537 blocks: HashMap<Vec<usize>, Vec<f64>>,
539 default_value: f64,
541}
542
543impl BlockSparseFactor {
544 pub fn new(variables: Vec<String>, cardinalities: Vec<usize>, block_size: usize) -> Self {
546 Self {
547 variables,
548 cardinalities,
549 block_size,
550 blocks: HashMap::new(),
551 default_value: 0.0,
552 }
553 }
554
555 pub fn from_factor(factor: &Factor, block_size: usize, threshold: f64) -> Self {
557 let shape: Vec<usize> = factor.values.shape().to_vec();
558 let mut sparse = Self::new(factor.variables.clone(), shape.clone(), block_size);
559 sparse.default_value = 0.0;
560 let block_dims: Vec<usize> = shape.iter().map(|&d| d.div_ceil(block_size)).collect();
561
562 let total_blocks: usize = block_dims.iter().product();
564 for block_flat in 0..total_blocks {
565 let block_indices = SparseFactor::flat_to_indices(block_flat, &block_dims);
566
567 let block_total = block_size.pow(shape.len() as u32);
569 let mut block_values = Vec::with_capacity(block_total);
570 let mut has_nonzero = false;
571
572 for local_flat in 0..block_total {
573 let local_indices =
574 SparseFactor::flat_to_indices(local_flat, &vec![block_size; shape.len()]);
575
576 let global_indices: Vec<usize> = block_indices
578 .iter()
579 .zip(local_indices.iter())
580 .zip(shape.iter())
581 .map(|((&bi, &li), &s)| (bi * block_size + li).min(s - 1))
582 .collect();
583
584 let value = factor.values[global_indices.as_slice()];
585 block_values.push(value);
586
587 if value.abs() > threshold {
588 has_nonzero = true;
589 }
590 }
591
592 if has_nonzero {
593 sparse.blocks.insert(block_indices, block_values);
594 }
595 }
596
597 sparse
598 }
599
600 pub fn num_blocks(&self) -> usize {
602 self.blocks.len()
603 }
604
605 pub fn block_sparsity(&self) -> f64 {
607 let block_dims: Vec<usize> = self
608 .cardinalities
609 .iter()
610 .map(|&d| d.div_ceil(self.block_size))
611 .collect();
612 let total_blocks: usize = block_dims.iter().product();
613
614 if total_blocks > 0 {
615 1.0 - (self.blocks.len() as f64 / total_blocks as f64)
616 } else {
617 1.0
618 }
619 }
620}
621
/// Estimates the memory needed for a factor graph with the given averages.
///
/// * Factors: `num_factors * avg_cardinality^avg_scope_size` `f64` entries.
/// * Messages: two messages of `avg_cardinality` entries per edge, with
///   `num_factors * avg_scope_size` edges.
/// * Marginals: `avg_cardinality` entries per variable.
///
/// All arithmetic saturates at `usize::MAX`, so very large graphs yield a
/// capped estimate instead of an overflow panic or a wrapped (tiny) number.
pub fn estimate_memory_usage(
    num_variables: usize,
    avg_cardinality: usize,
    num_factors: usize,
    avg_scope_size: usize,
) -> MemoryEstimate {
    let bytes_per_entry = std::mem::size_of::<f64>();

    // An exponent that does not fit in `u32` saturates the power anyway.
    let exponent = u32::try_from(avg_scope_size).unwrap_or(u32::MAX);
    let avg_factor_size = avg_cardinality.saturating_pow(exponent);
    let factor_bytes = num_factors
        .saturating_mul(avg_factor_size)
        .saturating_mul(bytes_per_entry);

    let edges = num_factors.saturating_mul(avg_scope_size);
    let message_bytes = 2usize
        .saturating_mul(edges)
        .saturating_mul(avg_cardinality)
        .saturating_mul(bytes_per_entry);

    let marginal_bytes = num_variables
        .saturating_mul(avg_cardinality)
        .saturating_mul(bytes_per_entry);

    MemoryEstimate {
        factor_bytes,
        message_bytes,
        marginal_bytes,
        total_bytes: factor_bytes
            .saturating_add(message_bytes)
            .saturating_add(marginal_bytes),
    }
}

/// Breakdown produced by [`estimate_memory_usage`]; all sizes are in bytes.
#[derive(Debug, Clone)]
pub struct MemoryEstimate {
    /// Bytes for the factor tables.
    pub factor_bytes: usize,
    /// Bytes for the messages along the edges.
    pub message_bytes: usize,
    /// Bytes for the per-variable marginals.
    pub marginal_bytes: usize,
    /// Sum of the three components.
    pub total_bytes: usize,
}

impl std::fmt::Display for MemoryEstimate {
    /// Formats every component in mebibytes with two decimals.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const BYTES_PER_MB: f64 = 1_048_576.0;
        let to_mb = |bytes: usize| bytes as f64 / BYTES_PER_MB;
        write!(
            f,
            "Memory Estimate: {:.2} MB total (factors: {:.2} MB, messages: {:.2} MB, marginals: {:.2} MB)",
            to_mb(self.total_bytes),
            to_mb(self.factor_bytes),
            to_mb(self.message_bytes),
            to_mb(self.marginal_bytes)
        )
    }
}
674
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    /// Builds a single-variable factor from a slice of values.
    fn unary_factor(name: &str, variable: &str, values: &[f64]) -> Result<Factor> {
        Factor::new(
            name.to_string(),
            vec![variable.to_string()],
            Array::from_vec(values.to_vec()).into_dyn(),
        )
    }

    /// Builds a two-variable (`x`, `y`) factor filled with one constant.
    fn constant_xy_factor(shape: &[usize], fill: f64) -> Factor {
        Factor::new(
            "test".to_string(),
            vec!["x".to_string(), "y".to_string()],
            ArrayD::from_elem(IxDyn(shape), fill),
        )
        .unwrap()
    }

    #[test]
    fn test_factor_pool_allocation() {
        let pool = FactorPool::new(10);

        // First allocation is a miss and yields a buffer of the right size.
        let fresh = pool.allocate(100);
        assert_eq!(fresh.len(), 100);

        pool.return_array(fresh);
        assert_eq!(pool.stats().returns, 1);

        // Second allocation of the same size is served from the pool.
        let reused = pool.allocate(100);
        assert_eq!(reused.len(), 100);
        assert_eq!(pool.stats().hits, 1);
    }

    #[test]
    fn test_factor_pool_hit_rate() {
        let pool = FactorPool::new(10);

        // One miss followed by one hit.
        let buffer = pool.allocate(50);
        pool.return_array(buffer);
        let _ = pool.allocate(50);

        assert!(pool.hit_rate() > 0.4);
    }

    #[test]
    fn test_sparse_factor_creation() {
        let mut sparse = SparseFactor::new(vec!["x".to_string()], vec![4]);

        sparse.set(vec![0], 1.0);
        sparse.set(vec![2], 0.5);

        assert_abs_diff_eq!(sparse.get(&[0]), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(&[1]), 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(&[2]), 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_sparse_factor_from_dense() {
        let factor = unary_factor("test", "x", &[0.0, 1.0, 0.0, 0.5]).unwrap();

        // Only the two values above the threshold are kept.
        let sparse = SparseFactor::from_dense(&factor, 0.1);
        assert_eq!(sparse.entries.len(), 2);

        // The round trip restores the significant values.
        let dense = sparse.to_dense().unwrap();
        assert_abs_diff_eq!(dense.values[[1]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(dense.values[[3]], 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_sparse_factor_sparsity() {
        let mut sparse = SparseFactor::new(vec!["x".to_string()], vec![100]);
        sparse.set(vec![50], 1.0);

        // One stored entry out of one hundred cells.
        assert!(sparse.sparsity() > 0.98);
    }

    #[test]
    fn test_lazy_factor_deferred() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_closure = Arc::clone(&calls);

        let lazy = LazyFactor::new(move || {
            calls_in_closure.fetch_add(1, Ordering::SeqCst);
            unary_factor("test", "x", &[0.5, 0.5])
        });

        // Nothing runs at construction time.
        assert!(!lazy.is_computed());
        assert_eq!(calls.load(Ordering::SeqCst), 0);

        // First evaluation triggers the computation.
        let _ = lazy.evaluate().unwrap();
        assert!(lazy.is_computed());
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // Second evaluation is answered from the cache.
        let _ = lazy.evaluate().unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_lazy_factor_from_factor() {
        let factor = unary_factor("test", "x", &[0.3, 0.7]).unwrap();

        let lazy = LazyFactor::from_factor(factor);
        assert!(lazy.is_computed());

        let result = lazy.evaluate().unwrap();
        assert_abs_diff_eq!(result.values[[0]], 0.3, epsilon = 1e-10);
    }

    #[test]
    fn test_compressed_factor() {
        let factor = unary_factor("test", "x", &[0.1, 0.2, 0.3, 0.4]).unwrap();

        let decompressed = CompressedFactor::from_factor(&factor)
            .to_factor()
            .unwrap();

        // Quantization is lossy, so only approximate equality is expected.
        for i in 0..4 {
            assert_abs_diff_eq!(factor.values[[i]], decompressed.values[[i]], epsilon = 0.01);
        }
    }

    #[test]
    fn test_compressed_factor_ratio() {
        let factor = constant_xy_factor(&[10, 10], 0.5);

        let ratio = CompressedFactor::from_factor(&factor).compression_ratio();

        // f64 (8 bytes) versus u16 (2 bytes) per value.
        assert!(ratio > 3.5);
    }

    #[test]
    fn test_streaming_factor_graph() {
        let mut graph = StreamingFactorGraph::new();
        for name in ["x", "y"] {
            graph.add_variable(name.to_string(), "Binary".to_string(), 2);
        }

        graph.add_factor(|| unary_factor("factor_x", "x", &[0.3, 0.7]));
        graph.add_factor(|| unary_factor("factor_y", "y", &[0.4, 0.6]));

        assert_eq!(graph.num_variables(), 2);
        assert_eq!(graph.num_factors(), 2);
    }

    #[test]
    fn test_memory_estimate() {
        let estimate = estimate_memory_usage(10, 3, 20, 3);

        assert!(estimate.total_bytes > 0);
        assert!(estimate.factor_bytes > 0);
        assert!(estimate.message_bytes > 0);
    }

    #[test]
    fn test_block_sparse_factor() {
        let factor = constant_xy_factor(&[8, 8], 0.0);

        let block_sparse = BlockSparseFactor::from_factor(&factor, 4, 0.001);

        // An all-zero factor stores no blocks at all.
        assert!(block_sparse.block_sparsity() > 0.99);
    }
}