Skip to main content

tensorlogic_quantrs_hooks/
memory.rs

//! Memory optimization utilities for large factor graphs.
//!
//! This module provides memory-efficient representations and operations for
//! probabilistic graphical models, including:
//!
//! - Memory pooling for reusing factor value buffers
//! - Sparse (coordinate-format) and block-sparse factor representations for
//!   factors with many zeros
//! - 16-bit quantized (compressed) factors
//! - Lazy evaluation for factor operations
//! - Streaming factor graphs that generate factors on demand for very large models
//! - Rough memory-usage estimation for factor graphs
10
11use scirs2_core::ndarray::{ArrayD, IxDyn};
12use std::collections::HashMap;
13use std::sync::{Arc, Mutex};
14
15use crate::error::{PgmError, Result};
16use crate::Factor;
17
/// Recycling pool for the flat `f64` buffers that back factor values.
///
/// Iterative algorithms such as message passing repeatedly allocate and drop
/// buffers of the same length. Handing a buffer back with
/// [`FactorPool::return_array`] lets a later [`FactorPool::allocate`] of the
/// same length reuse it instead of allocating.
#[derive(Debug)]
pub struct FactorPool {
    /// Idle buffers, bucketed by their length (element count).
    pools: Mutex<HashMap<usize, Vec<Vec<f64>>>>,
    /// Running usage counters; see [`PoolStats`].
    stats: Mutex<PoolStats>,
    /// Maximum number of idle buffers kept per length bucket; returns beyond
    /// this limit are simply dropped.
    max_pool_size: usize,
}

/// Usage counters reported by [`FactorPool::stats`].
#[derive(Clone, Debug, Default)]
pub struct PoolStats {
    /// Requests satisfied by reusing a pooled buffer.
    pub hits: usize,
    /// Requests that needed a fresh allocation.
    pub misses: usize,
    /// Buffers accepted back into the pool (buffers dropped because their
    /// bucket was full are not counted).
    pub returns: usize,
    /// Highest value `current_bytes` has reached, i.e. the most bytes ever
    /// held idle in the pool at once.
    pub peak_bytes: usize,
    /// Bytes currently held idle in the pool, computed from buffer length
    /// (not capacity).
    pub current_bytes: usize,
}

impl Default for FactorPool {
    /// Equivalent to `FactorPool::new(100)`.
    fn default() -> Self {
        FactorPool::new(100)
    }
}

impl FactorPool {
    /// Builds an empty pool that keeps at most `max_pool_size` idle buffers
    /// per buffer length.
    pub fn new(max_pool_size: usize) -> Self {
        FactorPool {
            max_pool_size,
            pools: Mutex::new(HashMap::new()),
            stats: Mutex::new(PoolStats::default()),
        }
    }

    /// Byte footprint used for accounting a buffer of `len` elements.
    fn byte_len(len: usize) -> usize {
        len * std::mem::size_of::<f64>()
    }

    /// Hands out a zero-filled buffer of `size` elements, reusing a pooled
    /// buffer of that length when one is available.
    pub fn allocate(&self, size: usize) -> Vec<f64> {
        // Lock order (pools, then stats) is the same in every method.
        let mut buckets = self.pools.lock().expect("lock should not be poisoned");
        let mut counters = self.stats.lock().expect("lock should not be poisoned");

        match buckets.get_mut(&size).and_then(|bucket| bucket.pop()) {
            Some(reused) => {
                counters.hits += 1;
                counters.current_bytes -= Self::byte_len(size);
                reused
            }
            None => {
                counters.misses += 1;
                vec![0.0; size]
            }
        }
    }

    /// Gives `array` back to the pool.
    ///
    /// The buffer is zeroed and stored under its length when that bucket has
    /// fewer than `max_pool_size` buffers; otherwise it is dropped.
    pub fn return_array(&self, mut array: Vec<f64>) {
        let size = array.len();
        let mut buckets = self.pools.lock().expect("lock should not be poisoned");
        let mut counters = self.stats.lock().expect("lock should not be poisoned");

        let bucket = buckets.entry(size).or_default();
        if bucket.len() >= self.max_pool_size {
            // Bucket is full: `array` is dropped when this function returns.
            return;
        }

        // Zero now so `allocate` can hand the buffer out as-is.
        array.fill(0.0);
        bucket.push(array);
        counters.returns += 1;
        counters.current_bytes += Self::byte_len(size);
        counters.peak_bytes = counters.peak_bytes.max(counters.current_bytes);
    }

    /// Snapshot of the current usage counters.
    pub fn stats(&self) -> PoolStats {
        let counters = self.stats.lock().expect("lock should not be poisoned");
        (*counters).clone()
    }

    /// Drops every idle buffer and resets `current_bytes`; other counters are kept.
    pub fn clear(&self) {
        let mut buckets = self.pools.lock().expect("lock should not be poisoned");
        let mut counters = self.stats.lock().expect("lock should not be poisoned");
        buckets.clear();
        counters.current_bytes = 0;
    }

    /// Fraction of `allocate` calls served from the pool (`0.0` before any call).
    pub fn hit_rate(&self) -> f64 {
        let counters = self.stats.lock().expect("lock should not be poisoned");
        match counters.hits + counters.misses {
            0 => 0.0,
            total => counters.hits as f64 / total as f64,
        }
    }
}
125
126/// Sparse factor representation using coordinate format.
127///
128/// Efficient for factors where most entries are zero or near-zero.
129#[derive(Debug, Clone)]
130pub struct SparseFactor {
131    /// Variable names
132    pub variables: Vec<String>,
133    /// Cardinalities for each variable
134    pub cardinalities: Vec<usize>,
135    /// Non-zero entries: (indices, value)
136    pub entries: Vec<(Vec<usize>, f64)>,
137    /// Default value for entries not in sparse representation
138    pub default_value: f64,
139}
140
141impl SparseFactor {
142    /// Create a new sparse factor.
143    pub fn new(variables: Vec<String>, cardinalities: Vec<usize>) -> Self {
144        Self {
145            variables,
146            cardinalities,
147            entries: Vec::new(),
148            default_value: 0.0,
149        }
150    }
151
152    /// Create from a dense factor with sparsity threshold.
153    ///
154    /// Values below threshold are treated as zero.
155    pub fn from_dense(factor: &Factor, threshold: f64) -> Self {
156        let shape: Vec<usize> = factor.values.shape().to_vec();
157        let mut sparse = Self::new(factor.variables.clone(), shape.clone());
158        sparse.default_value = 0.0;
159
160        let total_size: usize = shape.iter().product();
161
162        for i in 0..total_size {
163            let indices = Self::flat_to_indices(i, &shape);
164            let value = factor.values[indices.as_slice()];
165
166            if value.abs() > threshold {
167                sparse.entries.push((indices, value));
168            }
169        }
170
171        sparse
172    }
173
174    /// Convert to dense factor.
175    pub fn to_dense(&self) -> Result<Factor> {
176        let total_size: usize = self.cardinalities.iter().product();
177        let mut values = vec![self.default_value; total_size];
178
179        for (indices, value) in &self.entries {
180            let flat_idx = Self::indices_to_flat(indices, &self.cardinalities);
181            values[flat_idx] = *value;
182        }
183
184        let array = ArrayD::from_shape_vec(IxDyn(&self.cardinalities), values)?;
185
186        Factor::new("sparse".to_string(), self.variables.clone(), array)
187    }
188
189    /// Get value at indices.
190    pub fn get(&self, indices: &[usize]) -> f64 {
191        for (entry_indices, value) in &self.entries {
192            if entry_indices == indices {
193                return *value;
194            }
195        }
196        self.default_value
197    }
198
199    /// Set value at indices.
200    pub fn set(&mut self, indices: Vec<usize>, value: f64) {
201        // Check if entry exists
202        for (entry_indices, entry_value) in &mut self.entries {
203            if *entry_indices == indices {
204                *entry_value = value;
205                return;
206            }
207        }
208
209        // Add new entry
210        if (value - self.default_value).abs() > 1e-10 {
211            self.entries.push((indices, value));
212        }
213    }
214
215    /// Get sparsity ratio (fraction of non-default entries).
216    pub fn sparsity(&self) -> f64 {
217        let total_size: usize = self.cardinalities.iter().product();
218        if total_size > 0 {
219            1.0 - (self.entries.len() as f64 / total_size as f64)
220        } else {
221            1.0
222        }
223    }
224
225    /// Memory savings compared to dense representation.
226    pub fn memory_savings(&self) -> f64 {
227        let dense_bytes = self.cardinalities.iter().product::<usize>() * std::mem::size_of::<f64>();
228        let sparse_bytes = self.entries.len()
229            * (self.variables.len() * std::mem::size_of::<usize>() + std::mem::size_of::<f64>());
230
231        if dense_bytes > 0 {
232            1.0 - (sparse_bytes as f64 / dense_bytes as f64)
233        } else {
234            0.0
235        }
236    }
237
238    /// Convert flat index to multi-dimensional indices.
239    fn flat_to_indices(flat: usize, shape: &[usize]) -> Vec<usize> {
240        let mut indices = vec![0; shape.len()];
241        let mut remaining = flat;
242
243        for i in (0..shape.len()).rev() {
244            indices[i] = remaining % shape[i];
245            remaining /= shape[i];
246        }
247
248        indices
249    }
250
251    /// Convert multi-dimensional indices to flat index.
252    fn indices_to_flat(indices: &[usize], shape: &[usize]) -> usize {
253        let mut flat = 0;
254        let mut stride = 1;
255
256        for i in (0..shape.len()).rev() {
257            flat += indices[i] * stride;
258            stride *= shape[i];
259        }
260
261        flat
262    }
263}
264
265/// Lazy factor that defers computation until needed.
266///
267/// Useful for chaining operations without intermediate allocations.
268#[derive(Clone)]
269pub struct LazyFactor {
270    /// The computation to perform
271    computation: Arc<dyn Fn() -> Result<Factor> + Send + Sync>,
272    /// Cached result
273    cached: Arc<Mutex<Option<Factor>>>,
274}
275
276impl std::fmt::Debug for LazyFactor {
277    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
278        f.debug_struct("LazyFactor")
279            .field(
280                "cached",
281                &self
282                    .cached
283                    .lock()
284                    .expect("lock should not be poisoned")
285                    .is_some(),
286            )
287            .finish()
288    }
289}
290
291impl LazyFactor {
292    /// Create a new lazy factor from a computation.
293    pub fn new<F>(computation: F) -> Self
294    where
295        F: Fn() -> Result<Factor> + Send + Sync + 'static,
296    {
297        Self {
298            computation: Arc::new(computation),
299            cached: Arc::new(Mutex::new(None)),
300        }
301    }
302
303    /// Create from an already computed factor.
304    pub fn from_factor(factor: Factor) -> Self {
305        Self {
306            computation: Arc::new(move || {
307                Err(PgmError::InvalidDistribution(
308                    "Already computed".to_string(),
309                ))
310            }),
311            cached: Arc::new(Mutex::new(Some(factor))),
312        }
313    }
314
315    /// Evaluate the lazy factor, computing if necessary.
316    pub fn evaluate(&self) -> Result<Factor> {
317        let mut cached = self.cached.lock().expect("lock should not be poisoned");
318
319        if let Some(ref factor) = *cached {
320            return Ok(factor.clone());
321        }
322
323        let result = (self.computation)()?;
324        *cached = Some(result.clone());
325        Ok(result)
326    }
327
328    /// Check if the factor has been computed.
329    pub fn is_computed(&self) -> bool {
330        self.cached
331            .lock()
332            .expect("lock should not be poisoned")
333            .is_some()
334    }
335
336    /// Clear cached result to free memory.
337    pub fn clear_cache(&self) {
338        let mut cached = self.cached.lock().expect("lock should not be poisoned");
339        *cached = None;
340    }
341
342    /// Create a lazy product of two factors.
343    pub fn lazy_product(a: LazyFactor, b: LazyFactor) -> LazyFactor {
344        LazyFactor::new(move || {
345            let factor_a = a.evaluate()?;
346            let factor_b = b.evaluate()?;
347            factor_a.product(&factor_b)
348        })
349    }
350
351    /// Create a lazy marginalization.
352    pub fn lazy_marginalize(factor: LazyFactor, var: String) -> LazyFactor {
353        LazyFactor::new(move || {
354            let f = factor.evaluate()?;
355            f.marginalize_out(&var)
356        })
357    }
358}
359
360/// Memory-efficient factor graph for very large models.
361///
362/// Uses streaming computation and doesn't hold all factors in memory.
363pub struct StreamingFactorGraph {
364    /// Variable information
365    variables: HashMap<String, VariableInfo>,
366    /// Factor generators (compute on demand)
367    factor_generators: Vec<Box<dyn Fn() -> Result<Factor> + Send + Sync>>,
368    /// Memory pool for allocations (reserved for future use)
369    #[allow(dead_code)]
370    pool: Arc<FactorPool>,
371}
372
373/// Information about a variable.
374#[derive(Debug, Clone)]
375#[allow(dead_code)]
376struct VariableInfo {
377    domain: String,
378    cardinality: usize,
379}
380
381impl StreamingFactorGraph {
382    /// Create a new streaming factor graph.
383    pub fn new() -> Self {
384        Self {
385            variables: HashMap::new(),
386            factor_generators: Vec::new(),
387            pool: Arc::new(FactorPool::default()),
388        }
389    }
390
391    /// Create with a custom memory pool.
392    pub fn with_pool(pool: Arc<FactorPool>) -> Self {
393        Self {
394            variables: HashMap::new(),
395            factor_generators: Vec::new(),
396            pool,
397        }
398    }
399
400    /// Add a variable.
401    pub fn add_variable(&mut self, name: String, domain: String, cardinality: usize) {
402        self.variables.insert(
403            name,
404            VariableInfo {
405                domain,
406                cardinality,
407            },
408        );
409    }
410
411    /// Add a factor generator.
412    pub fn add_factor<F>(&mut self, generator: F)
413    where
414        F: Fn() -> Result<Factor> + Send + Sync + 'static,
415    {
416        self.factor_generators.push(Box::new(generator));
417    }
418
419    /// Stream factors one at a time for memory-efficient processing.
420    pub fn stream_factors(&self) -> impl Iterator<Item = Result<Factor>> + '_ {
421        self.factor_generators.iter().map(|gen| gen())
422    }
423
424    /// Compute the product of all factors using streaming.
425    ///
426    /// Memory efficient but may be slower than batch computation.
427    pub fn streaming_product(&self) -> Result<Factor> {
428        let mut result: Option<Factor> = None;
429
430        for gen in &self.factor_generators {
431            let factor = gen()?;
432
433            result = match result {
434                None => Some(factor),
435                Some(r) => Some(r.product(&factor)?),
436            };
437        }
438
439        result.ok_or_else(|| PgmError::InvalidDistribution("No factors in graph".to_string()))
440    }
441
442    /// Number of variables.
443    pub fn num_variables(&self) -> usize {
444        self.variables.len()
445    }
446
447    /// Number of factors.
448    pub fn num_factors(&self) -> usize {
449        self.factor_generators.len()
450    }
451}
452
453impl Default for StreamingFactorGraph {
454    fn default() -> Self {
455        Self::new()
456    }
457}
458
459/// Compressed factor using quantization.
460///
461/// Reduces memory usage by storing values at lower precision.
462#[derive(Debug, Clone)]
463pub struct CompressedFactor {
464    /// Variable names
465    pub variables: Vec<String>,
466    /// Cardinalities
467    pub cardinalities: Vec<usize>,
468    /// Quantized values (16-bit)
469    quantized: Vec<u16>,
470    /// Minimum value for dequantization
471    min_value: f64,
472    /// Scale for dequantization
473    scale: f64,
474}
475
476impl CompressedFactor {
477    /// Create from a dense factor.
478    pub fn from_factor(factor: &Factor) -> Self {
479        let values: Vec<f64> = factor.values.iter().copied().collect();
480        let cardinalities: Vec<usize> = factor.values.shape().to_vec();
481
482        let min_value = values.iter().copied().fold(f64::INFINITY, f64::min);
483        let max_value = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
484
485        let scale = if max_value > min_value {
486            (max_value - min_value) / 65535.0
487        } else {
488            1.0
489        };
490
491        let quantized: Vec<u16> = values
492            .iter()
493            .map(|&v| ((v - min_value) / scale).round() as u16)
494            .collect();
495
496        Self {
497            variables: factor.variables.clone(),
498            cardinalities,
499            quantized,
500            min_value,
501            scale,
502        }
503    }
504
505    /// Convert back to dense factor.
506    pub fn to_factor(&self) -> Result<Factor> {
507        let values: Vec<f64> = self
508            .quantized
509            .iter()
510            .map(|&q| self.min_value + (q as f64) * self.scale)
511            .collect();
512
513        let array = ArrayD::from_shape_vec(IxDyn(&self.cardinalities), values)?;
514
515        Factor::new("compressed".to_string(), self.variables.clone(), array)
516    }
517
518    /// Memory size in bytes.
519    pub fn memory_size(&self) -> usize {
520        self.quantized.len() * std::mem::size_of::<u16>()
521            + self.variables.len() * std::mem::size_of::<String>()
522            + self.cardinalities.len() * std::mem::size_of::<usize>()
523            + 2 * std::mem::size_of::<f64>()
524    }
525
526    /// Compression ratio compared to f64 representation.
527    pub fn compression_ratio(&self) -> f64 {
528        let original = self.quantized.len() * std::mem::size_of::<f64>();
529        let compressed = self.quantized.len() * std::mem::size_of::<u16>();
530
531        if compressed > 0 {
532            original as f64 / compressed as f64
533        } else {
534            1.0
535        }
536    }
537}
538
539/// Block-sparse factor for factors with block structure.
540///
541/// Efficient when non-zero entries are clustered in blocks.
542#[derive(Debug, Clone)]
543pub struct BlockSparseFactor {
544    /// Variable names
545    pub variables: Vec<String>,
546    /// Cardinalities
547    pub cardinalities: Vec<usize>,
548    /// Block size
549    pub block_size: usize,
550    /// Non-zero blocks: (block_index, values)
551    blocks: HashMap<Vec<usize>, Vec<f64>>,
552    /// Default block (all zeros or specific value)
553    default_value: f64,
554}
555
556impl BlockSparseFactor {
557    /// Create a new block-sparse factor.
558    pub fn new(variables: Vec<String>, cardinalities: Vec<usize>, block_size: usize) -> Self {
559        Self {
560            variables,
561            cardinalities,
562            block_size,
563            blocks: HashMap::new(),
564            default_value: 0.0,
565        }
566    }
567
568    /// Create from dense factor with sparsity detection.
569    pub fn from_factor(factor: &Factor, block_size: usize, threshold: f64) -> Self {
570        let shape: Vec<usize> = factor.values.shape().to_vec();
571        let mut sparse = Self::new(factor.variables.clone(), shape.clone(), block_size);
572        sparse.default_value = 0.0;
573        let block_dims: Vec<usize> = shape.iter().map(|&d| d.div_ceil(block_size)).collect();
574
575        // Iterate over blocks
576        let total_blocks: usize = block_dims.iter().product();
577        for block_flat in 0..total_blocks {
578            let block_indices = SparseFactor::flat_to_indices(block_flat, &block_dims);
579
580            // Extract block values
581            let block_total = block_size.pow(shape.len() as u32);
582            let mut block_values = Vec::with_capacity(block_total);
583            let mut has_nonzero = false;
584
585            for local_flat in 0..block_total {
586                let local_indices =
587                    SparseFactor::flat_to_indices(local_flat, &vec![block_size; shape.len()]);
588
589                // Compute global indices
590                let global_indices: Vec<usize> = block_indices
591                    .iter()
592                    .zip(local_indices.iter())
593                    .zip(shape.iter())
594                    .map(|((&bi, &li), &s)| (bi * block_size + li).min(s - 1))
595                    .collect();
596
597                let value = factor.values[global_indices.as_slice()];
598                block_values.push(value);
599
600                if value.abs() > threshold {
601                    has_nonzero = true;
602                }
603            }
604
605            if has_nonzero {
606                sparse.blocks.insert(block_indices, block_values);
607            }
608        }
609
610        sparse
611    }
612
613    /// Get number of non-zero blocks.
614    pub fn num_blocks(&self) -> usize {
615        self.blocks.len()
616    }
617
618    /// Get sparsity at block level.
619    pub fn block_sparsity(&self) -> f64 {
620        let block_dims: Vec<usize> = self
621            .cardinalities
622            .iter()
623            .map(|&d| d.div_ceil(self.block_size))
624            .collect();
625        let total_blocks: usize = block_dims.iter().product();
626
627        if total_blocks > 0 {
628            1.0 - (self.blocks.len() as f64 / total_blocks as f64)
629        } else {
630            1.0
631        }
632    }
633}
634
/// Estimate memory usage for a factor graph.
///
/// This is a rough model that assumes every factor has `avg_scope_size`
/// variables with `avg_cardinality` states each:
///
/// - factors: `num_factors * avg_cardinality^avg_scope_size` `f64` entries;
/// - belief-propagation messages: two directions per factor–variable edge
///   (`num_factors * avg_scope_size` edges), `avg_cardinality` entries each;
/// - marginals: `avg_cardinality` entries per variable.
///
/// All arithmetic saturates at `usize::MAX`. A model too large to address is
/// therefore reported as `usize::MAX` bytes instead of overflowing, which
/// would panic in debug builds and wrap in release builds.
pub fn estimate_memory_usage(
    num_variables: usize,
    avg_cardinality: usize,
    num_factors: usize,
    avg_scope_size: usize,
) -> MemoryEstimate {
    let bytes_per_entry = std::mem::size_of::<f64>();

    // Clamp rather than truncate the exponent. Any exponent beyond u32::MAX
    // already saturates (or yields 0/1 for cardinality 0/1).
    let exponent = avg_scope_size.min(u32::MAX as usize) as u32;
    let avg_factor_size = avg_cardinality.saturating_pow(exponent);
    let factor_bytes = num_factors
        .saturating_mul(avg_factor_size)
        .saturating_mul(bytes_per_entry);

    // Message storage for belief propagation.
    let edges = num_factors.saturating_mul(avg_scope_size);
    let message_bytes = edges
        .saturating_mul(2)
        .saturating_mul(avg_cardinality)
        .saturating_mul(bytes_per_entry);

    // Variable marginal storage.
    let marginal_bytes = num_variables
        .saturating_mul(avg_cardinality)
        .saturating_mul(bytes_per_entry);

    MemoryEstimate {
        factor_bytes,
        message_bytes,
        marginal_bytes,
        total_bytes: factor_bytes
            .saturating_add(message_bytes)
            .saturating_add(marginal_bytes),
    }
}

/// Memory usage estimate produced by [`estimate_memory_usage`], in bytes.
#[derive(Debug, Clone)]
pub struct MemoryEstimate {
    /// Memory for factor values.
    pub factor_bytes: usize,
    /// Memory for messages in belief propagation.
    pub message_bytes: usize,
    /// Memory for marginals.
    pub marginal_bytes: usize,
    /// Sum of the three components (saturating).
    pub total_bytes: usize,
}

impl std::fmt::Display for MemoryEstimate {
    /// Formats every component in mebibytes (2^20 bytes) with two decimals.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let to_mb = |bytes: usize| bytes as f64 / 1_048_576.0;
        write!(
            f,
            "Memory Estimate: {:.2} MB total (factors: {:.2} MB, messages: {:.2} MB, marginals: {:.2} MB)",
            to_mb(self.total_bytes),
            to_mb(self.factor_bytes),
            to_mb(self.message_bytes),
            to_mb(self.marginal_bytes)
        )
    }
}
687
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_factor_pool_allocation() {
        let pool = FactorPool::new(10);

        let arr1 = pool.allocate(100);
        assert_eq!(arr1.len(), 100);

        pool.return_array(arr1);
        assert_eq!(pool.stats().returns, 1);

        // Should reuse from pool
        let arr2 = pool.allocate(100);
        assert_eq!(arr2.len(), 100);
        assert_eq!(pool.stats().hits, 1);
    }

    #[test]
    fn test_factor_pool_hit_rate() {
        let pool = FactorPool::new(10);

        // First allocation is a miss.
        let arr = pool.allocate(50);
        pool.return_array(arr);

        // Second allocation of the same length is served from the pool.
        let _ = pool.allocate(50);

        // One hit out of two requests.
        assert_abs_diff_eq!(pool.hit_rate(), 0.5, epsilon = 1e-12);
    }

    #[test]
    fn test_factor_pool_respects_max_size() {
        let pool = FactorPool::new(1);

        pool.return_array(vec![1.0; 8]);
        // Bucket for length 8 is full now, so this one is dropped.
        pool.return_array(vec![1.0; 8]);

        let stats = pool.stats();
        assert_eq!(stats.returns, 1);
        assert_eq!(stats.current_bytes, 8 * std::mem::size_of::<f64>());

        // Pooled buffers are handed out zeroed.
        assert_eq!(pool.allocate(8), vec![0.0; 8]);
    }

    #[test]
    fn test_sparse_factor_creation() {
        let mut sparse = SparseFactor::new(vec!["x".to_string()], vec![4]);

        sparse.set(vec![0], 1.0);
        sparse.set(vec![2], 0.5);

        assert_abs_diff_eq!(sparse.get(&[0]), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(&[1]), 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(&[2]), 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_sparse_factor_from_dense() {
        let factor = Factor::new(
            "test".to_string(),
            vec!["x".to_string()],
            Array::from_vec(vec![0.0, 1.0, 0.0, 0.5]).into_dyn(),
        )
        .expect("unwrap");

        let sparse = SparseFactor::from_dense(&factor, 0.1);
        assert_eq!(sparse.entries.len(), 2); // Only 1.0 and 0.5

        let dense = sparse.to_dense().expect("unwrap");
        assert_abs_diff_eq!(dense.values[[1]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(dense.values[[3]], 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_sparse_factor_sparsity() {
        let mut sparse = SparseFactor::new(vec!["x".to_string()], vec![100]);
        sparse.set(vec![50], 1.0);

        // One stored entry out of 100 positions.
        assert_abs_diff_eq!(sparse.sparsity(), 0.99, epsilon = 1e-12);
    }

    #[test]
    fn test_lazy_factor_deferred() {
        let counter = Arc::new(Mutex::new(0));
        let counter_clone = counter.clone();

        let lazy = LazyFactor::new(move || {
            let mut count = counter_clone.lock().expect("unwrap");
            *count += 1;
            Factor::new(
                "test".to_string(),
                vec!["x".to_string()],
                Array::from_vec(vec![0.5, 0.5]).into_dyn(),
            )
        });

        assert!(!lazy.is_computed());
        assert_eq!(*counter.lock().expect("unwrap"), 0);

        let _ = lazy.evaluate().expect("unwrap");
        assert!(lazy.is_computed());
        assert_eq!(*counter.lock().expect("unwrap"), 1);

        // Second evaluation uses cache
        let _ = lazy.evaluate().expect("unwrap");
        assert_eq!(*counter.lock().expect("unwrap"), 1);
    }

    #[test]
    fn test_lazy_factor_from_factor() {
        let factor = Factor::new(
            "test".to_string(),
            vec!["x".to_string()],
            Array::from_vec(vec![0.3, 0.7]).into_dyn(),
        )
        .expect("unwrap");

        let lazy = LazyFactor::from_factor(factor);
        assert!(lazy.is_computed());

        let result = lazy.evaluate().expect("unwrap");
        assert_abs_diff_eq!(result.values[[0]], 0.3, epsilon = 1e-10);
    }

    #[test]
    fn test_compressed_factor() {
        let factor = Factor::new(
            "test".to_string(),
            vec!["x".to_string()],
            Array::from_vec(vec![0.1, 0.2, 0.3, 0.4]).into_dyn(),
        )
        .expect("unwrap");

        let compressed = CompressedFactor::from_factor(&factor);
        let decompressed = compressed.to_factor().expect("unwrap");

        // Values should be approximately preserved
        for i in 0..4 {
            assert_abs_diff_eq!(factor.values[[i]], decompressed.values[[i]], epsilon = 0.01);
        }
    }

    #[test]
    fn test_compressed_factor_ratio() {
        let factor = Factor::new(
            "test".to_string(),
            vec!["x".to_string(), "y".to_string()],
            ArrayD::from_elem(IxDyn(&[10, 10]), 0.5),
        )
        .expect("unwrap");

        let compressed = CompressedFactor::from_factor(&factor);

        // 16-bit codes vs 64-bit floats: exactly 4x.
        assert_abs_diff_eq!(compressed.compression_ratio(), 4.0, epsilon = 1e-12);
    }

    /// Graph with two independent binary factors over `x` and `y`.
    fn two_factor_graph() -> StreamingFactorGraph {
        let mut graph = StreamingFactorGraph::new();
        graph.add_variable("x".to_string(), "Binary".to_string(), 2);
        graph.add_variable("y".to_string(), "Binary".to_string(), 2);

        graph.add_factor(|| {
            Factor::new(
                "factor_x".to_string(),
                vec!["x".to_string()],
                Array::from_vec(vec![0.3, 0.7]).into_dyn(),
            )
        });

        graph.add_factor(|| {
            Factor::new(
                "factor_y".to_string(),
                vec!["y".to_string()],
                Array::from_vec(vec![0.4, 0.6]).into_dyn(),
            )
        });

        graph
    }

    #[test]
    fn test_streaming_factor_graph() {
        let graph = two_factor_graph();

        assert_eq!(graph.num_variables(), 2);
        assert_eq!(graph.num_factors(), 2);
        assert!(graph.stream_factors().all(|factor| factor.is_ok()));
    }

    #[test]
    fn test_streaming_product() {
        let graph = two_factor_graph();
        let product = graph.streaming_product().expect("unwrap");

        // Joint over two binary variables; the product of two normalized
        // independent factors sums to 1.
        assert_eq!(product.values.len(), 4);
        assert_abs_diff_eq!(product.values.sum(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_streaming_product_empty_graph() {
        let graph = StreamingFactorGraph::new();
        assert!(graph.streaming_product().is_err());
    }

    #[test]
    fn test_memory_estimate() {
        let estimate = estimate_memory_usage(10, 3, 20, 3);

        // factors: 20 * 3^3 * 8, messages: 2 * (20 * 3) * 3 * 8, marginals: 10 * 3 * 8
        assert_eq!(estimate.factor_bytes, 4320);
        assert_eq!(estimate.message_bytes, 2880);
        assert_eq!(estimate.marginal_bytes, 240);
        assert_eq!(estimate.total_bytes, 7440);
    }

    #[test]
    fn test_block_sparse_factor() {
        let factor = Factor::new(
            "test".to_string(),
            vec!["x".to_string(), "y".to_string()],
            ArrayD::from_elem(IxDyn(&[8, 8]), 0.0),
        )
        .expect("unwrap");

        let block_sparse = BlockSparseFactor::from_factor(&factor, 4, 0.001);

        // All zeros: no block is stored.
        assert_eq!(block_sparse.num_blocks(), 0);
        assert!(block_sparse.block_sparsity() > 0.99);
    }

    #[test]
    fn test_block_sparse_factor_ragged_edge() {
        // 5x5 with block size 4 gives a 2x2 block grid whose edge blocks overhang.
        let mut values = ArrayD::from_elem(IxDyn(&[5, 5]), 0.0);
        let corner: &[usize] = &[4, 4];
        values[corner] = 1.0;

        let factor = Factor::new(
            "test".to_string(),
            vec!["x".to_string(), "y".to_string()],
            values,
        )
        .expect("unwrap");

        let block_sparse = BlockSparseFactor::from_factor(&factor, 4, 0.001);

        // Only the bottom-right block contains the non-zero corner.
        assert_eq!(block_sparse.num_blocks(), 1);
        assert_abs_diff_eq!(block_sparse.block_sparsity(), 0.75, epsilon = 1e-12);
    }
}