tensorlogic_quantrs_hooks/memory.rs

//! Memory optimization utilities for large factor graphs.
//!
//! This module provides memory-efficient representations and operations for
//! probabilistic graphical models, including:
//!
//! - Memory pooling for factor allocation
//! - Sparse, block-sparse, and quantization-compressed factor representations
//! - Lazy evaluation for factor operations
//! - Streaming factor graphs and memory estimation for very large models

use scirs2_core::ndarray::{ArrayD, IxDyn};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use crate::error::{PgmError, Result};
use crate::Factor;

/// Memory pool for factor value arrays.
///
/// Reuses allocated arrays to reduce allocation overhead
/// in iterative algorithms like message passing.
#[derive(Debug)]
pub struct FactorPool {
    /// Pool of available arrays, keyed by total array length
    pools: Mutex<HashMap<usize, Vec<Vec<f64>>>>,
    /// Statistics
    stats: Mutex<PoolStats>,
    /// Maximum number of pooled arrays per size class
    max_pool_size: usize,
}

/// Statistics for memory pool usage.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Number of allocations served from the pool
    pub hits: usize,
    /// Number of new allocations
    pub misses: usize,
    /// Number of arrays returned to the pool
    pub returns: usize,
    /// Peak memory held in the pool, in bytes
    pub peak_bytes: usize,
    /// Memory currently held in the pool, in bytes
    pub current_bytes: usize,
}

impl Default for FactorPool {
    fn default() -> Self {
        Self::new(100)
    }
}

impl FactorPool {
    /// Create a new factor pool with a maximum number of pooled arrays per size class.
    pub fn new(max_pool_size: usize) -> Self {
        Self {
            pools: Mutex::new(HashMap::new()),
            stats: Mutex::new(PoolStats::default()),
            max_pool_size,
        }
    }

    /// Allocate or reuse an array of the given size.
    pub fn allocate(&self, size: usize) -> Vec<f64> {
        let mut pools = self.pools.lock().unwrap();
        let mut stats = self.stats.lock().unwrap();

        if let Some(pool) = pools.get_mut(&size) {
            if let Some(array) = pool.pop() {
                stats.hits += 1;
                stats.current_bytes -= size * std::mem::size_of::<f64>();
                return array;
            }
        }

        stats.misses += 1;
        vec![0.0; size]
    }

    /// Return an array to the pool for reuse.
    pub fn return_array(&self, mut array: Vec<f64>) {
        let size = array.len();
        let mut pools = self.pools.lock().unwrap();
        let mut stats = self.stats.lock().unwrap();

        let pool = pools.entry(size).or_default();
        if pool.len() < self.max_pool_size {
            // Clear and return to pool
            array.fill(0.0);
            pool.push(array);
            stats.returns += 1;
            stats.current_bytes += size * std::mem::size_of::<f64>();
            stats.peak_bytes = stats.peak_bytes.max(stats.current_bytes);
        }
        // Otherwise, let it drop
    }

    /// Get pool statistics.
    pub fn stats(&self) -> PoolStats {
        self.stats.lock().unwrap().clone()
    }

    /// Clear all pooled arrays.
    pub fn clear(&self) {
        let mut pools = self.pools.lock().unwrap();
        let mut stats = self.stats.lock().unwrap();
        pools.clear();
        stats.current_bytes = 0;
    }

    /// Get the pool hit rate (hits / total allocations), or 0.0 if nothing has been allocated.
    pub fn hit_rate(&self) -> f64 {
        let stats = self.stats.lock().unwrap();
        let total = stats.hits + stats.misses;
        if total > 0 {
            stats.hits as f64 / total as f64
        } else {
            0.0
        }
    }
}

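// A minimal usage sketch for `FactorPool`, not part of the original API surface (the
// helper name and sizes are illustrative): an iterative algorithm can recycle message
// buffers instead of reallocating them on every pass.
#[allow(dead_code)]
fn example_factor_pool_usage() {
    let pool = FactorPool::new(16);

    // First allocation is a miss; the pool hands back a zero-filled buffer.
    let buffer = pool.allocate(256);
    assert_eq!(buffer.len(), 256);

    // Returning the buffer makes it available for the next allocation of length 256.
    pool.return_array(buffer);
    let reused = pool.allocate(256);
    assert_eq!(reused.len(), 256);

    // One hit out of two allocations.
    assert!((pool.hit_rate() - 0.5).abs() < 1e-12);
    println!("pool stats: {:?}", pool.stats());
}
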
/// Sparse factor representation using coordinate format.
///
/// Efficient for factors where most entries are zero or near-zero.
#[derive(Debug, Clone)]
pub struct SparseFactor {
    /// Variable names
    pub variables: Vec<String>,
    /// Cardinalities for each variable
    pub cardinalities: Vec<usize>,
    /// Non-zero entries: (indices, value)
    pub entries: Vec<(Vec<usize>, f64)>,
    /// Default value for entries not in the sparse representation
    pub default_value: f64,
}

impl SparseFactor {
    /// Create a new sparse factor.
    pub fn new(variables: Vec<String>, cardinalities: Vec<usize>) -> Self {
        Self {
            variables,
            cardinalities,
            entries: Vec::new(),
            default_value: 0.0,
        }
    }

    /// Create from a dense factor with a sparsity threshold.
    ///
    /// Values whose absolute value is at or below the threshold are treated as zero
    /// and not stored.
    pub fn from_dense(factor: &Factor, threshold: f64) -> Self {
        let shape: Vec<usize> = factor.values.shape().to_vec();
        let mut sparse = Self::new(factor.variables.clone(), shape.clone());
        sparse.default_value = 0.0;

        let total_size: usize = shape.iter().product();

        for i in 0..total_size {
            let indices = Self::flat_to_indices(i, &shape);
            let value = factor.values[indices.as_slice()];

            if value.abs() > threshold {
                sparse.entries.push((indices, value));
            }
        }

        sparse
    }

    /// Convert to a dense factor.
    pub fn to_dense(&self) -> Result<Factor> {
        let total_size: usize = self.cardinalities.iter().product();
        let mut values = vec![self.default_value; total_size];

        for (indices, value) in &self.entries {
            let flat_idx = Self::indices_to_flat(indices, &self.cardinalities);
            values[flat_idx] = *value;
        }

        let array = ArrayD::from_shape_vec(IxDyn(&self.cardinalities), values)?;

        Factor::new("sparse".to_string(), self.variables.clone(), array)
    }

    /// Get the value at the given indices (linear scan over stored entries).
    pub fn get(&self, indices: &[usize]) -> f64 {
        for (entry_indices, value) in &self.entries {
            if entry_indices == indices {
                return *value;
            }
        }
        self.default_value
    }

    /// Set the value at the given indices.
    pub fn set(&mut self, indices: Vec<usize>, value: f64) {
        // Update the entry if it already exists
        for (entry_indices, entry_value) in &mut self.entries {
            if *entry_indices == indices {
                *entry_value = value;
                return;
            }
        }

        // Add a new entry only if the value differs from the default
        if (value - self.default_value).abs() > 1e-10 {
            self.entries.push((indices, value));
        }
    }

    /// Get the sparsity ratio: the fraction of entries equal to the default
    /// (i.e. not explicitly stored).
    pub fn sparsity(&self) -> f64 {
        let total_size: usize = self.cardinalities.iter().product();
        if total_size > 0 {
            1.0 - (self.entries.len() as f64 / total_size as f64)
        } else {
            1.0
        }
    }

    /// Estimated memory savings compared to a dense representation.
    pub fn memory_savings(&self) -> f64 {
        let dense_bytes = self.cardinalities.iter().product::<usize>() * std::mem::size_of::<f64>();
        let sparse_bytes = self.entries.len()
            * (self.variables.len() * std::mem::size_of::<usize>() + std::mem::size_of::<f64>());

        if dense_bytes > 0 {
            1.0 - (sparse_bytes as f64 / dense_bytes as f64)
        } else {
            0.0
        }
    }

    /// Convert a flat (row-major) index to multi-dimensional indices.
    fn flat_to_indices(flat: usize, shape: &[usize]) -> Vec<usize> {
        let mut indices = vec![0; shape.len()];
        let mut remaining = flat;

        for i in (0..shape.len()).rev() {
            indices[i] = remaining % shape[i];
            remaining /= shape[i];
        }

        indices
    }

    /// Convert multi-dimensional indices to a flat (row-major) index.
    fn indices_to_flat(indices: &[usize], shape: &[usize]) -> usize {
        let mut flat = 0;
        let mut stride = 1;

        for i in (0..shape.len()).rev() {
            flat += indices[i] * stride;
            stride *= shape[i];
        }

        flat
    }
}

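// A minimal usage sketch, not part of the original API surface (the helper name and the
// toy 8-entry factor are illustrative): it shows the dense -> sparse -> dense round trip
// plus the sparsity and memory-savings reporting.
#[allow(dead_code)]
fn example_sparse_factor_usage() {
    // A mostly-zero dense factor over one variable with cardinality 8.
    let dense = Factor::new(
        "phi".to_string(),
        vec!["x".to_string()],
        ArrayD::from_shape_vec(IxDyn(&[8]), vec![0.0, 0.9, 0.0, 0.0, 0.0, 0.1, 0.0, 0.0]).unwrap(),
    )
    .unwrap();

    // Keep only entries whose absolute value exceeds the threshold.
    let sparse = SparseFactor::from_dense(&dense, 1e-6);
    assert_eq!(sparse.entries.len(), 2);
    assert!(sparse.sparsity() >= 0.75);
    println!("memory savings: {:.0}%", 100.0 * sparse.memory_savings());

    // Round-trip back to a dense factor; stored values are preserved exactly.
    let roundtrip = sparse.to_dense().unwrap();
    assert!((roundtrip.values[[1]] - 0.9).abs() < 1e-12);
}
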
/// Lazy factor that defers computation until needed.
///
/// Useful for chaining operations without intermediate allocations.
#[derive(Clone)]
pub struct LazyFactor {
    /// The computation to perform
    computation: Arc<dyn Fn() -> Result<Factor> + Send + Sync>,
    /// Cached result
    cached: Arc<Mutex<Option<Factor>>>,
}

impl std::fmt::Debug for LazyFactor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LazyFactor")
            .field("cached", &self.cached.lock().unwrap().is_some())
            .finish()
    }
}

impl LazyFactor {
    /// Create a new lazy factor from a computation.
    pub fn new<F>(computation: F) -> Self
    where
        F: Fn() -> Result<Factor> + Send + Sync + 'static,
    {
        Self {
            computation: Arc::new(computation),
            cached: Arc::new(Mutex::new(None)),
        }
    }

    /// Create from an already computed factor.
    pub fn from_factor(factor: Factor) -> Self {
        Self {
            computation: Arc::new(move || {
                Err(PgmError::InvalidDistribution(
                    "Already computed".to_string(),
                ))
            }),
            cached: Arc::new(Mutex::new(Some(factor))),
        }
    }

    /// Evaluate the lazy factor, computing if necessary.
    pub fn evaluate(&self) -> Result<Factor> {
        let mut cached = self.cached.lock().unwrap();

        if let Some(ref factor) = *cached {
            return Ok(factor.clone());
        }

        let result = (self.computation)()?;
        *cached = Some(result.clone());
        Ok(result)
    }

    /// Check if the factor has been computed.
    pub fn is_computed(&self) -> bool {
        self.cached.lock().unwrap().is_some()
    }

    /// Clear the cached result to free memory.
    pub fn clear_cache(&self) {
        let mut cached = self.cached.lock().unwrap();
        *cached = None;
    }

    /// Create a lazy product of two factors.
    pub fn lazy_product(a: LazyFactor, b: LazyFactor) -> LazyFactor {
        LazyFactor::new(move || {
            let factor_a = a.evaluate()?;
            let factor_b = b.evaluate()?;
            factor_a.product(&factor_b)
        })
    }

    /// Create a lazy marginalization.
    pub fn lazy_marginalize(factor: LazyFactor, var: String) -> LazyFactor {
        LazyFactor::new(move || {
            let f = factor.evaluate()?;
            f.marginalize_out(&var)
        })
    }
}

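// A minimal sketch of chaining deferred operations, not part of the original API surface
// (names here are illustrative; it relies on `Factor::product` and `Factor::marginalize_out`
// exactly as the combinators above do). Nothing runs until `evaluate`, and the result is
// cached afterwards.
#[allow(dead_code)]
fn example_lazy_factor_usage() {
    let a = LazyFactor::from_factor(
        Factor::new(
            "a".to_string(),
            vec!["x".to_string()],
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![0.3, 0.7]).unwrap(),
        )
        .unwrap(),
    );
    let b = LazyFactor::from_factor(
        Factor::new(
            "b".to_string(),
            vec!["y".to_string()],
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![0.4, 0.6]).unwrap(),
        )
        .unwrap(),
    );

    // Build the expression without materializing any intermediate factor.
    let marginal = LazyFactor::lazy_marginalize(LazyFactor::lazy_product(a, b), "y".to_string());
    assert!(!marginal.is_computed());

    // Force evaluation once; subsequent calls reuse the cache.
    let result = marginal.evaluate().unwrap();
    println!("result has {} values", result.values.len());
    assert!(marginal.is_computed());
}
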
/// Memory-efficient factor graph for very large models.
///
/// Uses streaming computation and doesn't hold all factors in memory.
pub struct StreamingFactorGraph {
    /// Variable information
    variables: HashMap<String, VariableInfo>,
    /// Factor generators (compute on demand)
    factor_generators: Vec<Box<dyn Fn() -> Result<Factor> + Send + Sync>>,
    /// Memory pool for allocations (reserved for future use)
    #[allow(dead_code)]
    pool: Arc<FactorPool>,
}

/// Information about a variable.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct VariableInfo {
    domain: String,
    cardinality: usize,
}

impl StreamingFactorGraph {
    /// Create a new streaming factor graph.
    pub fn new() -> Self {
        Self {
            variables: HashMap::new(),
            factor_generators: Vec::new(),
            pool: Arc::new(FactorPool::default()),
        }
    }

    /// Create with a custom memory pool.
    pub fn with_pool(pool: Arc<FactorPool>) -> Self {
        Self {
            variables: HashMap::new(),
            factor_generators: Vec::new(),
            pool,
        }
    }

    /// Add a variable.
    pub fn add_variable(&mut self, name: String, domain: String, cardinality: usize) {
        self.variables.insert(
            name,
            VariableInfo {
                domain,
                cardinality,
            },
        );
    }

    /// Add a factor generator.
    pub fn add_factor<F>(&mut self, generator: F)
    where
        F: Fn() -> Result<Factor> + Send + Sync + 'static,
    {
        self.factor_generators.push(Box::new(generator));
    }

    /// Stream factors one at a time for memory-efficient processing.
    pub fn stream_factors(&self) -> impl Iterator<Item = Result<Factor>> + '_ {
        self.factor_generators.iter().map(|generator| generator())
    }

    /// Compute the product of all factors using streaming.
    ///
    /// Memory efficient, but may be slower than batch computation.
    pub fn streaming_product(&self) -> Result<Factor> {
        let mut result: Option<Factor> = None;

        for generator in &self.factor_generators {
            let factor = generator()?;

            result = match result {
                None => Some(factor),
                Some(r) => Some(r.product(&factor)?),
            };
        }

        result.ok_or_else(|| PgmError::InvalidDistribution("No factors in graph".to_string()))
    }

    /// Number of variables.
    pub fn num_variables(&self) -> usize {
        self.variables.len()
    }

    /// Number of factors.
    pub fn num_factors(&self) -> usize {
        self.factor_generators.len()
    }
}

impl Default for StreamingFactorGraph {
    fn default() -> Self {
        Self::new()
    }
}

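// A minimal sketch, not part of the original API surface (the factor names are illustrative):
// factors are registered as generators and only materialized while `streaming_product` folds
// over them, so at most one generated factor plus the running product is alive at a time.
#[allow(dead_code)]
fn example_streaming_factor_graph_usage() {
    let mut graph = StreamingFactorGraph::new();
    graph.add_variable("x".to_string(), "Binary".to_string(), 2);
    graph.add_variable("y".to_string(), "Binary".to_string(), 2);

    // Each generator builds its factor on demand.
    graph.add_factor(|| {
        Factor::new(
            "p_x".to_string(),
            vec!["x".to_string()],
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![0.3, 0.7]).unwrap(),
        )
    });
    graph.add_factor(|| {
        Factor::new(
            "p_y".to_string(),
            vec!["y".to_string()],
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![0.4, 0.6]).unwrap(),
        )
    });

    // Fold all generated factors into a single joint factor, one at a time.
    let joint = graph.streaming_product().unwrap();
    println!("joint factor has {} entries", joint.values.len());
}
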
/// Compressed factor using quantization.
///
/// Reduces memory usage by storing values at lower precision.
#[derive(Debug, Clone)]
pub struct CompressedFactor {
    /// Variable names
    pub variables: Vec<String>,
    /// Cardinalities
    pub cardinalities: Vec<usize>,
    /// Quantized values (16-bit)
    quantized: Vec<u16>,
    /// Minimum value for dequantization
    min_value: f64,
    /// Scale for dequantization
    scale: f64,
}

impl CompressedFactor {
    /// Create from a dense factor.
    pub fn from_factor(factor: &Factor) -> Self {
        let values: Vec<f64> = factor.values.iter().copied().collect();
        let cardinalities: Vec<usize> = factor.values.shape().to_vec();

        let min_value = values.iter().copied().fold(f64::INFINITY, f64::min);
        let max_value = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);

        let scale = if max_value > min_value {
            (max_value - min_value) / 65535.0
        } else {
            1.0
        };

        let quantized: Vec<u16> = values
            .iter()
            .map(|&v| ((v - min_value) / scale).round() as u16)
            .collect();

        Self {
            variables: factor.variables.clone(),
            cardinalities,
            quantized,
            min_value,
            scale,
        }
    }

    /// Convert back to a dense factor.
    pub fn to_factor(&self) -> Result<Factor> {
        let values: Vec<f64> = self
            .quantized
            .iter()
            .map(|&q| self.min_value + (q as f64) * self.scale)
            .collect();

        let array = ArrayD::from_shape_vec(IxDyn(&self.cardinalities), values)?;

        Factor::new("compressed".to_string(), self.variables.clone(), array)
    }

    /// Approximate memory size in bytes (excludes heap data owned by the variable name strings).
    pub fn memory_size(&self) -> usize {
        self.quantized.len() * std::mem::size_of::<u16>()
            + self.variables.len() * std::mem::size_of::<String>()
            + self.cardinalities.len() * std::mem::size_of::<usize>()
            + 2 * std::mem::size_of::<f64>()
    }

    /// Compression ratio of the value storage compared to an f64 representation.
    pub fn compression_ratio(&self) -> f64 {
        let original = self.quantized.len() * std::mem::size_of::<f64>();
        let compressed = self.quantized.len() * std::mem::size_of::<u16>();

        if compressed > 0 {
            original as f64 / compressed as f64
        } else {
            1.0
        }
    }
}

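// A minimal sketch, not part of the original API surface, of the 16-bit quantization round
// trip: values are mapped onto 65,536 levels between the factor's min and max, so the
// reconstruction error per entry is at most about (max - min) / 65535.
#[allow(dead_code)]
fn example_compressed_factor_usage() {
    let values: Vec<f64> = (0..16).map(|i| i as f64 / 15.0).collect();
    let factor = Factor::new(
        "phi".to_string(),
        vec!["x".to_string(), "y".to_string()],
        ArrayD::from_shape_vec(IxDyn(&[4, 4]), values).unwrap(),
    )
    .unwrap();

    let compressed = CompressedFactor::from_factor(&factor);
    println!(
        "~{} bytes, {:.1}x smaller value storage than f64",
        compressed.memory_size(),
        compressed.compression_ratio()
    );

    // Dequantize and compare against the original values.
    let restored = compressed.to_factor().unwrap();
    for (orig, rest) in factor.values.iter().zip(restored.values.iter()) {
        assert!((orig - rest).abs() < 1e-4);
    }
}
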
/// Block-sparse factor for factors with block structure.
///
/// Efficient when non-zero entries are clustered in blocks.
#[derive(Debug, Clone)]
pub struct BlockSparseFactor {
    /// Variable names
    pub variables: Vec<String>,
    /// Cardinalities
    pub cardinalities: Vec<usize>,
    /// Block size (edge length of each block along every dimension)
    pub block_size: usize,
    /// Non-zero blocks, keyed by block indices
    blocks: HashMap<Vec<usize>, Vec<f64>>,
    /// Default value for blocks that are not stored
    default_value: f64,
}

impl BlockSparseFactor {
    /// Create a new block-sparse factor.
    pub fn new(variables: Vec<String>, cardinalities: Vec<usize>, block_size: usize) -> Self {
        Self {
            variables,
            cardinalities,
            block_size,
            blocks: HashMap::new(),
            default_value: 0.0,
        }
    }

    /// Create from a dense factor with sparsity detection.
    pub fn from_factor(factor: &Factor, block_size: usize, threshold: f64) -> Self {
        let shape: Vec<usize> = factor.values.shape().to_vec();
        let mut sparse = Self::new(factor.variables.clone(), shape.clone(), block_size);
        sparse.default_value = 0.0;
        let block_dims: Vec<usize> = shape.iter().map(|&d| d.div_ceil(block_size)).collect();

        // Iterate over blocks
        let total_blocks: usize = block_dims.iter().product();
        for block_flat in 0..total_blocks {
            let block_indices = SparseFactor::flat_to_indices(block_flat, &block_dims);

            // Extract block values
            let block_total = block_size.pow(shape.len() as u32);
            let mut block_values = Vec::with_capacity(block_total);
            let mut has_nonzero = false;

            for local_flat in 0..block_total {
                let local_indices =
                    SparseFactor::flat_to_indices(local_flat, &vec![block_size; shape.len()]);

                // Compute global indices, clamping to stay in bounds for partial edge blocks
                let global_indices: Vec<usize> = block_indices
                    .iter()
                    .zip(local_indices.iter())
                    .zip(shape.iter())
                    .map(|((&bi, &li), &s)| (bi * block_size + li).min(s - 1))
                    .collect();

                let value = factor.values[global_indices.as_slice()];
                block_values.push(value);

                if value.abs() > threshold {
                    has_nonzero = true;
                }
            }

            if has_nonzero {
                sparse.blocks.insert(block_indices, block_values);
            }
        }

        sparse
    }

    /// Get the number of non-zero blocks.
    pub fn num_blocks(&self) -> usize {
        self.blocks.len()
    }

    /// Get sparsity at the block level (fraction of blocks that are not stored).
    pub fn block_sparsity(&self) -> f64 {
        let block_dims: Vec<usize> = self
            .cardinalities
            .iter()
            .map(|&d| d.div_ceil(self.block_size))
            .collect();
        let total_blocks: usize = block_dims.iter().product();

        if total_blocks > 0 {
            1.0 - (self.blocks.len() as f64 / total_blocks as f64)
        } else {
            1.0
        }
    }
}

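// A minimal sketch, not part of the original API surface (the helper name and toy factor are
// illustrative): a factor whose mass sits in one corner compresses to a single stored 4x4
// block, and `block_sparsity` reports the fraction of blocks that never had to be stored.
#[allow(dead_code)]
fn example_block_sparse_usage() {
    // 8x8 factor that is zero everywhere except the top-left 4x4 corner.
    let mut values = ArrayD::from_elem(IxDyn(&[8, 8]), 0.0);
    for i in 0..4 {
        for j in 0..4 {
            values[[i, j]] = 1.0;
        }
    }
    let factor = Factor::new(
        "phi".to_string(),
        vec!["x".to_string(), "y".to_string()],
        values,
    )
    .unwrap();

    // With 4x4 blocks, only 1 of the 4 blocks contains non-zero values.
    let block_sparse = BlockSparseFactor::from_factor(&factor, 4, 1e-9);
    assert_eq!(block_sparse.num_blocks(), 1);
    assert!((block_sparse.block_sparsity() - 0.75).abs() < 1e-12);
}
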
/// Estimate memory usage for a factor graph.
pub fn estimate_memory_usage(
    num_variables: usize,
    avg_cardinality: usize,
    num_factors: usize,
    avg_scope_size: usize,
) -> MemoryEstimate {
    let bytes_per_entry = std::mem::size_of::<f64>();
    let avg_factor_size = avg_cardinality.pow(avg_scope_size as u32);
    let total_factor_bytes = num_factors * avg_factor_size * bytes_per_entry;

    // Message storage for belief propagation
    let edges = num_factors * avg_scope_size;
    let message_bytes = 2 * edges * avg_cardinality * bytes_per_entry;

    // Variable marginal storage
    let marginal_bytes = num_variables * avg_cardinality * bytes_per_entry;

    MemoryEstimate {
        factor_bytes: total_factor_bytes,
        message_bytes,
        marginal_bytes,
        total_bytes: total_factor_bytes + message_bytes + marginal_bytes,
    }
}

/// Memory usage estimate.
#[derive(Debug, Clone)]
pub struct MemoryEstimate {
    /// Memory for factor values
    pub factor_bytes: usize,
    /// Memory for messages in belief propagation
    pub message_bytes: usize,
    /// Memory for marginals
    pub marginal_bytes: usize,
    /// Total estimated memory
    pub total_bytes: usize,
}

impl std::fmt::Display for MemoryEstimate {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let to_mb = |bytes: usize| bytes as f64 / 1_048_576.0;
        write!(
            f,
            "Memory Estimate: {:.2} MB total (factors: {:.2} MB, messages: {:.2} MB, marginals: {:.2} MB)",
            to_mb(self.total_bytes),
            to_mb(self.factor_bytes),
            to_mb(self.message_bytes),
            to_mb(self.marginal_bytes)
        )
    }
}

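// A minimal sketch, not part of the original API surface: a rough pre-flight check of whether
// a model of a given size fits in memory before building it. The parameters below (1,000
// ternary variables, 2,000 factors of average scope 3) are arbitrary illustrative values.
#[allow(dead_code)]
fn example_memory_estimate_usage() {
    let estimate = estimate_memory_usage(1_000, 3, 2_000, 3);
    println!("{estimate}");
    assert_eq!(
        estimate.total_bytes,
        estimate.factor_bytes + estimate.message_bytes + estimate.marginal_bytes
    );
}
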
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_factor_pool_allocation() {
        let pool = FactorPool::new(10);

        let arr1 = pool.allocate(100);
        assert_eq!(arr1.len(), 100);

        pool.return_array(arr1);
        assert_eq!(pool.stats().returns, 1);

        // Should reuse from pool
        let arr2 = pool.allocate(100);
        assert_eq!(arr2.len(), 100);
        assert_eq!(pool.stats().hits, 1);
    }

    #[test]
    fn test_factor_pool_hit_rate() {
        let pool = FactorPool::new(10);

        // First allocation is a miss
        let arr = pool.allocate(50);
        pool.return_array(arr);

        // Second should be a hit
        let _ = pool.allocate(50);

        assert!(pool.hit_rate() > 0.4); // At least one hit
    }

    #[test]
    fn test_sparse_factor_creation() {
        let mut sparse = SparseFactor::new(vec!["x".to_string()], vec![4]);

        sparse.set(vec![0], 1.0);
        sparse.set(vec![2], 0.5);

        assert_abs_diff_eq!(sparse.get(&[0]), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(&[1]), 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(&[2]), 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_sparse_factor_from_dense() {
        let factor = Factor::new(
            "test".to_string(),
            vec!["x".to_string()],
            Array::from_vec(vec![0.0, 1.0, 0.0, 0.5]).into_dyn(),
        )
        .unwrap();

        let sparse = SparseFactor::from_dense(&factor, 0.1);
        assert_eq!(sparse.entries.len(), 2); // Only 1.0 and 0.5

        let dense = sparse.to_dense().unwrap();
        assert_abs_diff_eq!(dense.values[[1]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(dense.values[[3]], 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_sparse_factor_sparsity() {
        let mut sparse = SparseFactor::new(vec!["x".to_string()], vec![100]);
        sparse.set(vec![50], 1.0);

        let sparsity = sparse.sparsity();
        assert!(sparsity > 0.98); // 99% sparse
    }

    #[test]
    fn test_lazy_factor_deferred() {
        let counter = Arc::new(Mutex::new(0));
        let counter_clone = counter.clone();

        let lazy = LazyFactor::new(move || {
            let mut count = counter_clone.lock().unwrap();
            *count += 1;
            Factor::new(
                "test".to_string(),
                vec!["x".to_string()],
                Array::from_vec(vec![0.5, 0.5]).into_dyn(),
            )
        });

        assert!(!lazy.is_computed());
        assert_eq!(*counter.lock().unwrap(), 0);

        let _ = lazy.evaluate().unwrap();
        assert!(lazy.is_computed());
        assert_eq!(*counter.lock().unwrap(), 1);

        // Second evaluation uses the cache
        let _ = lazy.evaluate().unwrap();
        assert_eq!(*counter.lock().unwrap(), 1);
    }

    #[test]
    fn test_lazy_factor_from_factor() {
        let factor = Factor::new(
            "test".to_string(),
            vec!["x".to_string()],
            Array::from_vec(vec![0.3, 0.7]).into_dyn(),
        )
        .unwrap();

        let lazy = LazyFactor::from_factor(factor);
        assert!(lazy.is_computed());

        let result = lazy.evaluate().unwrap();
        assert_abs_diff_eq!(result.values[[0]], 0.3, epsilon = 1e-10);
    }

    #[test]
    fn test_compressed_factor() {
        let factor = Factor::new(
            "test".to_string(),
            vec!["x".to_string()],
            Array::from_vec(vec![0.1, 0.2, 0.3, 0.4]).into_dyn(),
        )
        .unwrap();

        let compressed = CompressedFactor::from_factor(&factor);
        let decompressed = compressed.to_factor().unwrap();

        // Values should be approximately preserved
        for i in 0..4 {
            assert_abs_diff_eq!(factor.values[[i]], decompressed.values[[i]], epsilon = 0.01);
        }
    }

    #[test]
    fn test_compressed_factor_ratio() {
        let factor = Factor::new(
            "test".to_string(),
            vec!["x".to_string(), "y".to_string()],
            ArrayD::from_elem(IxDyn(&[10, 10]), 0.5),
        )
        .unwrap();

        let compressed = CompressedFactor::from_factor(&factor);
        let ratio = compressed.compression_ratio();

        // 16-bit vs 64-bit should give ~4x compression
        assert!(ratio > 3.5);
    }

    #[test]
    fn test_streaming_factor_graph() {
        let mut graph = StreamingFactorGraph::new();
        graph.add_variable("x".to_string(), "Binary".to_string(), 2);
        graph.add_variable("y".to_string(), "Binary".to_string(), 2);

        graph.add_factor(|| {
            Factor::new(
                "factor_x".to_string(),
                vec!["x".to_string()],
                Array::from_vec(vec![0.3, 0.7]).into_dyn(),
            )
        });

        graph.add_factor(|| {
            Factor::new(
                "factor_y".to_string(),
                vec!["y".to_string()],
                Array::from_vec(vec![0.4, 0.6]).into_dyn(),
            )
        });

        assert_eq!(graph.num_variables(), 2);
        assert_eq!(graph.num_factors(), 2);
    }

    #[test]
    fn test_memory_estimate() {
        let estimate = estimate_memory_usage(10, 3, 20, 3);

        assert!(estimate.total_bytes > 0);
        assert!(estimate.factor_bytes > 0);
        assert!(estimate.message_bytes > 0);
    }

    #[test]
    fn test_block_sparse_factor() {
        let factor = Factor::new(
            "test".to_string(),
            vec!["x".to_string(), "y".to_string()],
            ArrayD::from_elem(IxDyn(&[8, 8]), 0.0),
        )
        .unwrap();

        let block_sparse = BlockSparseFactor::from_factor(&factor, 4, 0.001);

        // All zeros should give high block sparsity
        let sparsity = block_sparse.block_sparsity();
        assert!(sparsity > 0.99);
    }
}