tensorlogic_quantrs_hooks/
factor.rs

1//! Factor representation and operations.
2
3use scirs2_core::ndarray::ArrayD;
4use serde::{Deserialize, Serialize};
5
6use crate::error::{PgmError, Result};
7
/// A factor in a probabilistic graphical model.
///
/// Represents a function over a subset of variables: φ(X₁, X₂, ..., Xₖ) → ℝ⁺
///
/// The value table is positional: axis `i` of `values` belongs to
/// `variables[i]`, and the length of that axis is the variable's cardinality.
/// [`Factor::new`] enforces that the number of axes equals the number of
/// variables; because the fields are public, code that builds a `Factor`
/// directly must uphold this itself.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Factor {
    /// Variables this factor depends on, in axis order of `values`.
    pub variables: Vec<String>,
    /// Probability/potential values, one axis per entry of `variables`.
    pub values: ArrayD<f64>,
    /// Factor name for debugging; derived factors build their names from it.
    pub name: String,
}
20
21impl Factor {
22    /// Create a new factor.
23    pub fn new(name: String, variables: Vec<String>, values: ArrayD<f64>) -> Result<Self> {
24        // Validate dimensions match number of variables
25        if values.ndim() != variables.len() {
26            return Err(PgmError::DimensionMismatch {
27                expected: vec![variables.len()],
28                got: vec![values.ndim()],
29            });
30        }
31
32        Ok(Self {
33            name,
34            variables,
35            values,
36        })
37    }
38
39    /// Create a uniform factor.
40    pub fn uniform(name: String, variables: Vec<String>, card: usize) -> Self {
41        let shape = vec![card; variables.len()];
42        let values = ArrayD::from_elem(shape, 1.0 / (card.pow(variables.len() as u32) as f64));
43        Self {
44            name,
45            variables,
46            values,
47        }
48    }
49
50    /// Normalize factor to sum to 1.
51    pub fn normalize(&mut self) {
52        let sum: f64 = self.values.iter().sum();
53        if sum > 0.0 {
54            self.values /= sum;
55        }
56    }
57
58    /// Get cardinality of a variable.
59    pub fn get_cardinality(&self, var: &str) -> Option<usize> {
60        self.variables
61            .iter()
62            .position(|v| v == var)
63            .map(|idx| self.values.shape()[idx])
64    }
65}
66
/// Operations on factors.
///
/// A plain tag enum with no payload, so it is cheap to copy, compare, hash
/// and print.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FactorOp {
    /// Product of factors
    Product,
    /// Sum over variables
    Marginalize,
    /// Divide factors
    Divide,
}
76
77impl Factor {
78    /// Compute the product of two factors.
79    ///
80    /// φ₁(X₁) * φ₂(X₂) = φ(X₁ ∪ X₂)
81    pub fn product(&self, other: &Factor) -> Result<Factor> {
82        // Find union of variables
83        let mut all_vars = self.variables.clone();
84        for v in &other.variables {
85            if !all_vars.contains(v) {
86                all_vars.push(v.clone());
87            }
88        }
89
90        // Build shape and index mappings
91        let mut shape = Vec::new();
92        let mut self_mapping = Vec::new(); // Maps result dims to self dims
93        let mut other_mapping = Vec::new(); // Maps result dims to other dims
94
95        for var in &all_vars {
96            // Find the variable in both factors
97            let self_idx_opt = self.variables.iter().position(|v| v == var);
98            let other_idx_opt = other.variables.iter().position(|v| v == var);
99
100            let cardinality = if let Some(self_idx) = self_idx_opt {
101                self_mapping.push(Some(self_idx));
102                self.values.shape()[self_idx]
103            } else if let Some(other_idx) = other_idx_opt {
104                self_mapping.push(None);
105                other.values.shape()[other_idx]
106            } else {
107                unreachable!("Variable must be in at least one factor");
108            };
109
110            if let Some(other_idx) = other_idx_opt {
111                other_mapping.push(Some(other_idx));
112            } else {
113                other_mapping.push(None);
114            }
115
116            shape.push(cardinality);
117        }
118
119        // Compute product
120        let mut result_values = ArrayD::zeros(shape.clone());
121        let total_size: usize = shape.iter().product();
122
123        for linear_idx in 0..total_size {
124            // Convert linear index to multi-dimensional assignment
125            let mut assignment = Vec::new();
126            let mut temp_idx = linear_idx;
127            for &dim in shape.iter().rev() {
128                assignment.push(temp_idx % dim);
129                temp_idx /= dim;
130            }
131            assignment.reverse();
132
133            // Map to indices for self and other
134            let self_idx: Vec<usize> = self_mapping
135                .iter()
136                .enumerate()
137                .filter_map(|(i, &opt)| opt.map(|_| assignment[i]))
138                .collect();
139
140            let other_idx: Vec<usize> = other_mapping
141                .iter()
142                .enumerate()
143                .filter_map(|(i, &opt)| opt.map(|_| assignment[i]))
144                .collect();
145
146            // Get values
147            let self_val = if self_idx.len() == self.variables.len() {
148                self.values[self_idx.as_slice()]
149            } else {
150                1.0
151            };
152
153            let other_val = if other_idx.len() == other.variables.len() {
154                other.values[other_idx.as_slice()]
155            } else {
156                1.0
157            };
158
159            result_values[assignment.as_slice()] = self_val * other_val;
160        }
161
162        Ok(Factor {
163            name: format!("{}*{}", self.name, other.name),
164            variables: all_vars,
165            values: result_values,
166        })
167    }
168
169    /// Marginalize out a variable by summing over it.
170    ///
171    /// ∑ₓ φ(X, Y) = φ(Y)
172    pub fn marginalize_out(&self, var: &str) -> Result<Factor> {
173        use scirs2_core::ndarray::Axis;
174
175        // Find variable index
176        let var_idx = self
177            .variables
178            .iter()
179            .position(|v| v == var)
180            .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
181
182        // Sum over the axis
183        let new_values = self.values.sum_axis(Axis(var_idx));
184
185        // Remove variable from list
186        let new_vars: Vec<String> = self
187            .variables
188            .iter()
189            .filter(|v| *v != var)
190            .cloned()
191            .collect();
192
193        Ok(Factor {
194            name: format!("{}_marg", self.name),
195            variables: new_vars,
196            values: new_values,
197        })
198    }
199
200    /// Marginalize out multiple variables.
201    pub fn marginalize_out_vars(&self, vars: &[String]) -> Result<Factor> {
202        let mut result = self.clone();
203        for var in vars {
204            result = result.marginalize_out(var)?;
205        }
206        Ok(result)
207    }
208
209    /// Marginalize out all variables except the specified ones.
210    ///
211    /// This is useful for extracting marginals: to get P(X), marginalize out all variables except X.
212    pub fn marginalize_out_all_except(&self, keep_vars: &[String]) -> Result<Factor> {
213        let vars_to_remove: Vec<String> = self
214            .variables
215            .iter()
216            .filter(|v| !keep_vars.contains(v))
217            .cloned()
218            .collect();
219
220        self.marginalize_out_vars(&vars_to_remove)
221    }
222
223    /// Maximize out a variable (for max-product algorithm).
224    ///
225    /// max_x φ(X, Y) = φ(Y) where φ(Y) = max over X
226    pub fn maximize_out(&self, var: &str) -> Result<Factor> {
227        use scirs2_core::ndarray::Axis;
228
229        // Find variable index
230        let var_idx = self
231            .variables
232            .iter()
233            .position(|v| v == var)
234            .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
235
236        // Take max over the axis
237        let new_values = self.values.map_axis(Axis(var_idx), |view| {
238            view.iter().fold(f64::NEG_INFINITY, |acc, &x| acc.max(x))
239        });
240
241        // Remove variable from list
242        let new_vars: Vec<String> = self
243            .variables
244            .iter()
245            .filter(|v| *v != var)
246            .cloned()
247            .collect();
248
249        Ok(Factor {
250            name: format!("{}_max", self.name),
251            variables: new_vars,
252            values: new_values,
253        })
254    }
255
256    /// Maximize out multiple variables.
257    pub fn maximize_out_vars(&self, vars: &[String]) -> Result<Factor> {
258        let mut result = self.clone();
259        for var in vars {
260            result = result.maximize_out(var)?;
261        }
262        Ok(result)
263    }
264
265    /// Divide this factor by another factor.
266    ///
267    /// φ₁(X) / φ₂(X) - used for message division
268    pub fn divide(&self, other: &Factor) -> Result<Factor> {
269        // Variables must match
270        if self.variables != other.variables {
271            return Err(PgmError::InvalidDistribution(
272                "Cannot divide factors with different variables".to_string(),
273            ));
274        }
275
276        // Perform element-wise division with safeguard
277        let result_values = &self.values
278            / &other
279                .values
280                .mapv(|x| if x.abs() < 1e-10 { 1e-10 } else { x });
281
282        Ok(Factor {
283            name: format!("{}/{}", self.name, other.name),
284            variables: self.variables.clone(),
285            values: result_values,
286        })
287    }
288
289    /// Reduce factor to specific variable assignment (evidence).
290    pub fn reduce(&self, var: &str, value: usize) -> Result<Factor> {
291        use scirs2_core::ndarray::Axis;
292
293        let var_idx = self
294            .variables
295            .iter()
296            .position(|v| v == var)
297            .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
298
299        // Check bounds
300        if value >= self.values.shape()[var_idx] {
301            return Err(PgmError::InvalidDistribution(format!(
302                "Value {} out of bounds for variable {} with cardinality {}",
303                value,
304                var,
305                self.values.shape()[var_idx]
306            )));
307        }
308
309        // Slice at the given value
310        let new_values = self.values.index_axis(Axis(var_idx), value).to_owned();
311
312        // Remove variable
313        let new_vars: Vec<String> = self
314            .variables
315            .iter()
316            .filter(|v| *v != var)
317            .cloned()
318            .collect();
319
320        Ok(Factor {
321            name: format!("{}_reduced", self.name),
322            variables: new_vars,
323            values: new_values,
324        })
325    }
326}
327
#[cfg(test)]
mod tests {
    //! Unit tests for factor construction and the factor algebra.
    //!
    //! Tables are built in row-major order, so for a `[2, 2]` table over
    //! `(x, y)` the flat vector `[a, b, c, d]` means
    //! `φ(0,0)=a, φ(0,1)=b, φ(1,0)=c, φ(1,1)=d`.

    use super::*;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_factor_creation() {
        let values = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
            .unwrap()
            .into_dyn();
        let factor = Factor::new(
            "f1".to_string(),
            vec!["x".to_string(), "y".to_string()],
            values,
        )
        .unwrap();

        assert_eq!(factor.variables.len(), 2);
        assert_eq!(factor.values.ndim(), 2);
    }

    #[test]
    fn test_factor_normalize() {
        let values = Array::from_shape_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0])
            .unwrap()
            .into_dyn();
        let mut factor = Factor::new(
            "f1".to_string(),
            vec!["x".to_string(), "y".to_string()],
            values,
        )
        .unwrap();

        factor.normalize();
        let sum: f64 = factor.values.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_uniform_factor() {
        let factor = Factor::uniform("f1".to_string(), vec!["x".to_string()], 3);
        assert_eq!(factor.values.len(), 3);
        let sum: f64 = factor.values.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_factor_product() {
        // φ₁(X) and φ₂(Y) → φ(X,Y)
        let f1_values = Array::from_shape_vec(vec![2], vec![0.6, 0.4])
            .unwrap()
            .into_dyn();
        let f1 = Factor::new("f1".to_string(), vec!["x".to_string()], f1_values).unwrap();

        let f2_values = Array::from_shape_vec(vec![2], vec![0.7, 0.3])
            .unwrap()
            .into_dyn();
        let f2 = Factor::new("f2".to_string(), vec!["y".to_string()], f2_values).unwrap();

        let product = f1.product(&f2).unwrap();
        assert_eq!(product.variables, ["x", "y"]);
        assert_eq!(product.values.shape(), &[2, 2]);

        // Check every cell: φ(x, y) = φ₁(x) · φ₂(y). Checking only the total
        // would not catch transposed or misplaced entries.
        let expected = [[0.6 * 0.7, 0.6 * 0.3], [0.4 * 0.7, 0.4 * 0.3]];
        for (x, row) in expected.iter().enumerate() {
            for (y, &want) in row.iter().enumerate() {
                assert!((product.values[[x, y]] - want).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_factor_marginalize() {
        // φ(X,Y) → φ(X)
        let values = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
            .unwrap()
            .into_dyn();
        let factor = Factor::new(
            "f1".to_string(),
            vec!["x".to_string(), "y".to_string()],
            values,
        )
        .unwrap();

        let marginal = factor.marginalize_out("y").unwrap();
        assert_eq!(marginal.variables.len(), 1);
        assert_eq!(marginal.variables[0], "x");
        assert_eq!(marginal.values.shape(), &[2]);

        // Sum over Y: [0.1+0.2, 0.3+0.4] = [0.3, 0.7]
        assert!((marginal.values[[0]] - 0.3).abs() < 1e-10);
        assert!((marginal.values[[1]] - 0.7).abs() < 1e-10);
    }

    #[test]
    fn test_factor_divide() {
        let values1 = Array::from_shape_vec(vec![2], vec![0.6, 0.4])
            .unwrap()
            .into_dyn();
        let f1 = Factor::new("f1".to_string(), vec!["x".to_string()], values1).unwrap();

        let values2 = Array::from_shape_vec(vec![2], vec![0.3, 0.2])
            .unwrap()
            .into_dyn();
        let f2 = Factor::new("f2".to_string(), vec!["x".to_string()], values2).unwrap();

        let result = f1.divide(&f2).unwrap();
        assert_eq!(result.variables.len(), 1);

        // 0.6/0.3 = 2.0, 0.4/0.2 = 2.0
        assert!((result.values[[0]] - 2.0).abs() < 1e-10);
        assert!((result.values[[1]] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_factor_reduce() {
        // φ(X,Y) with evidence Y=1 → φ(X)
        let values = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
            .unwrap()
            .into_dyn();
        let factor = Factor::new(
            "f1".to_string(),
            vec!["x".to_string(), "y".to_string()],
            values,
        )
        .unwrap();

        let reduced = factor.reduce("y", 1).unwrap();
        assert_eq!(reduced.variables.len(), 1);
        assert_eq!(reduced.variables[0], "x");

        // Y=1 slice: [0.2, 0.4]
        assert!((reduced.values[[0]] - 0.2).abs() < 1e-10);
        assert!((reduced.values[[1]] - 0.4).abs() < 1e-10);
    }

    #[test]
    fn test_factor_product_with_shared_vars() {
        // φ₁(X,Y) and φ₂(Y,Z) → φ(X,Y,Z)
        let f1_values = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
            .unwrap()
            .into_dyn();
        let f1 = Factor::new(
            "f1".to_string(),
            vec!["x".to_string(), "y".to_string()],
            f1_values,
        )
        .unwrap();

        // Distinct entries so that a wrong axis mapping changes the result.
        let f2_values = Array::from_shape_vec(vec![2, 2], vec![0.5, 0.6, 0.7, 0.8])
            .unwrap()
            .into_dyn();
        let f2 = Factor::new(
            "f2".to_string(),
            vec!["y".to_string(), "z".to_string()],
            f2_values,
        )
        .unwrap();

        let product = f1.product(&f2).unwrap();
        assert_eq!(product.variables, ["x", "y", "z"]);
        assert_eq!(product.values.shape(), &[2, 2, 2]);

        // φ(x, y, z) = φ₁(x, y) · φ₂(y, z) for every joint assignment.
        for x in 0..2 {
            for y in 0..2 {
                for z in 0..2 {
                    let want = f1.values[[x, y]] * f2.values[[y, z]];
                    assert!((product.values[[x, y, z]] - want).abs() < 1e-10);
                }
            }
        }
    }

    #[test]
    fn test_factor_maximize() {
        // φ(X,Y) → max_Y φ(X)
        let values = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
            .unwrap()
            .into_dyn();
        let factor = Factor::new(
            "f1".to_string(),
            vec!["x".to_string(), "y".to_string()],
            values,
        )
        .unwrap();

        let maximized = factor.maximize_out("y").unwrap();
        assert_eq!(maximized.variables.len(), 1);
        assert_eq!(maximized.variables[0], "x");
        assert_eq!(maximized.values.shape(), &[2]);

        // Max over Y: [max(0.1, 0.2), max(0.3, 0.4)] = [0.2, 0.4]
        assert!((maximized.values[[0]] - 0.2).abs() < 1e-10);
        assert!((maximized.values[[1]] - 0.4).abs() < 1e-10);
    }

    #[test]
    fn test_factor_maximize_multiple() {
        // φ(X,Y,Z) → max_{Y,Z} φ(X)
        let values =
            Array::from_shape_vec(vec![2, 2, 2], vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
                .unwrap()
                .into_dyn();
        let factor = Factor::new(
            "f1".to_string(),
            vec!["x".to_string(), "y".to_string(), "z".to_string()],
            values,
        )
        .unwrap();

        let maximized = factor
            .maximize_out_vars(&["y".to_string(), "z".to_string()])
            .unwrap();
        assert_eq!(maximized.variables.len(), 1);
        assert_eq!(maximized.variables[0], "x");
        assert_eq!(maximized.values.shape(), &[2]);

        // X=0 covers 0.1..=0.4 and X=1 covers 0.5..=0.8 in row-major order.
        assert!((maximized.values[[0]] - 0.4).abs() < 1e-10);
        assert!((maximized.values[[1]] - 0.8).abs() < 1e-10);
    }
}