tensorlogic_quantrs_hooks/
graph.rs

1//! Factor graph representation.
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use crate::error::{PgmError, Result};
7use crate::factor::Factor;
8
9/// Variable node in a factor graph.
10#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
11pub struct VariableNode {
12    /// Variable name
13    pub name: String,
14    /// Domain of the variable
15    pub domain: String,
16    /// Cardinality (number of possible values)
17    pub cardinality: usize,
18}
19
20/// Factor node in a factor graph.
21#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
22pub struct FactorNode {
23    /// Factor ID
24    pub id: String,
25    /// Connected variable names
26    pub variables: Vec<String>,
27}
28
29/// Factor graph representation for PGM.
30#[derive(Clone, Debug)]
31pub struct FactorGraph {
32    /// Variable nodes
33    variables: HashMap<String, VariableNode>,
34    /// Factor nodes
35    factors: HashMap<String, Factor>,
36    /// Adjacency: variable -> connected factors
37    var_to_factors: HashMap<String, Vec<String>>,
38    /// Adjacency: factor -> connected variables
39    factor_to_vars: HashMap<String, Vec<String>>,
40}
41
42impl FactorGraph {
43    /// Create a new empty factor graph.
44    pub fn new() -> Self {
45        Self {
46            variables: HashMap::new(),
47            factors: HashMap::new(),
48            var_to_factors: HashMap::new(),
49            factor_to_vars: HashMap::new(),
50        }
51    }
52
53    /// Add a variable to the graph.
54    pub fn add_variable(&mut self, name: String, domain: String) {
55        let node = VariableNode {
56            name: name.clone(),
57            domain,
58            cardinality: 2, // Default binary
59        };
60        self.variables.insert(name.clone(), node);
61        self.var_to_factors.entry(name).or_default();
62    }
63
64    /// Add a variable with specific cardinality.
65    pub fn add_variable_with_card(&mut self, name: String, domain: String, cardinality: usize) {
66        let node = VariableNode {
67            name: name.clone(),
68            domain,
69            cardinality,
70        };
71        self.variables.insert(name.clone(), node);
72        self.var_to_factors.entry(name).or_default();
73    }
74
75    /// Add a factor to the graph.
76    pub fn add_factor(&mut self, factor: Factor) -> Result<()> {
77        let factor_id = factor.name.clone();
78
79        // Ensure all variables exist
80        for var in &factor.variables {
81            if !self.variables.contains_key(var) {
82                return Err(PgmError::VariableNotFound(var.clone()));
83            }
84        }
85
86        // Update adjacency lists
87        for var in &factor.variables {
88            self.var_to_factors
89                .entry(var.clone())
90                .or_default()
91                .push(factor_id.clone());
92        }
93        self.factor_to_vars
94            .insert(factor_id.clone(), factor.variables.clone());
95
96        self.factors.insert(factor_id, factor);
97        Ok(())
98    }
99
100    /// Add a factor from predicate name and variables.
101    pub fn add_factor_from_predicate(&mut self, name: &str, var_names: &[String]) -> Result<()> {
102        // Create uniform factor
103        let factor = Factor::uniform(name.to_string(), var_names.to_vec(), 2);
104        self.add_factor(factor)
105    }
106
107    /// Get variable node.
108    pub fn get_variable(&self, name: &str) -> Option<&VariableNode> {
109        self.variables.get(name)
110    }
111
112    /// Get factor.
113    pub fn get_factor(&self, id: &str) -> Option<&Factor> {
114        self.factors.get(id)
115    }
116
117    /// Get factors connected to a variable.
118    pub fn get_adjacent_factors(&self, var: &str) -> Option<&Vec<String>> {
119        self.var_to_factors.get(var)
120    }
121
122    /// Get variables connected to a factor.
123    pub fn get_adjacent_variables(&self, factor_id: &str) -> Option<&Vec<String>> {
124        self.factor_to_vars.get(factor_id)
125    }
126
127    /// Get number of variables.
128    pub fn num_variables(&self) -> usize {
129        self.variables.len()
130    }
131
132    /// Get number of factors.
133    pub fn num_factors(&self) -> usize {
134        self.factors.len()
135    }
136
137    /// Check if graph is empty.
138    pub fn is_empty(&self) -> bool {
139        self.variables.is_empty() && self.factors.is_empty()
140    }
141
142    /// Get all variable names.
143    pub fn variable_names(&self) -> impl Iterator<Item = &String> {
144        self.variables.keys()
145    }
146
147    /// Get all factor IDs.
148    pub fn factor_ids(&self) -> impl Iterator<Item = &String> {
149        self.factors.keys()
150    }
151
152    /// Get all variables as an iterator.
153    pub fn variables(&self) -> impl Iterator<Item = (&String, &VariableNode)> {
154        self.variables.iter()
155    }
156
157    /// Get all factors as an iterator.
158    pub fn factors(&self) -> impl Iterator<Item = &Factor> {
159        self.factors.values()
160    }
161
162    /// Get all factors as a vector (for external use).
163    pub fn get_all_factors(&self) -> Vec<&Factor> {
164        self.factors.values().collect()
165    }
166}
167
168impl Default for FactorGraph {
169    fn default() -> Self {
170        Self::new()
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Graph with two binary variables: `x` over domain `D1` and `y` over domain `D2`.
    fn graph_with_xy() -> FactorGraph {
        let mut graph = FactorGraph::new();
        for &(name, domain) in &[("x", "D1"), ("y", "D2")] {
            graph.add_variable(name.to_string(), domain.to_string());
        }
        graph
    }

    #[test]
    fn test_graph_creation() {
        assert!(FactorGraph::new().is_empty());
    }

    #[test]
    fn test_add_variables() {
        let graph = graph_with_xy();

        assert_eq!(graph.num_variables(), 2);
        assert!(graph.get_variable("x").is_some());
    }

    #[test]
    fn test_add_factor() {
        let mut graph = graph_with_xy();
        let scope = ["x".to_string(), "y".to_string()];

        graph
            .add_factor_from_predicate("P", &scope)
            .expect("adding a factor over existing variables should succeed");
        assert_eq!(graph.num_factors(), 1);
    }

    #[test]
    fn test_adjacency() {
        let mut graph = graph_with_xy();
        let scope = ["x".to_string(), "y".to_string()];
        graph.add_factor_from_predicate("P", &scope).unwrap();

        let adjacent = graph
            .get_adjacent_factors("x")
            .expect("x should have an adjacency entry");
        assert_eq!(adjacent.len(), 1);
    }
}