Skip to main content

tensorlogic_quantrs_hooks/
graph.rs

1//! Factor graph representation.
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use crate::error::{PgmError, Result};
7use crate::factor::Factor;
8
9/// Variable node in a factor graph.
10#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
11pub struct VariableNode {
12    /// Variable name
13    pub name: String,
14    /// Domain of the variable
15    pub domain: String,
16    /// Cardinality (number of possible values)
17    pub cardinality: usize,
18}
19
20/// Factor node in a factor graph.
21#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
22pub struct FactorNode {
23    /// Factor ID
24    pub id: String,
25    /// Connected variable names
26    pub variables: Vec<String>,
27}
28
29/// Factor graph representation for PGM.
30#[derive(Clone, Debug)]
31pub struct FactorGraph {
32    /// Variable nodes
33    variables: HashMap<String, VariableNode>,
34    /// Factor nodes
35    factors: HashMap<String, Factor>,
36    /// Adjacency: variable -> connected factors
37    var_to_factors: HashMap<String, Vec<String>>,
38    /// Adjacency: factor -> connected variables
39    factor_to_vars: HashMap<String, Vec<String>>,
40}
41
42impl FactorGraph {
43    /// Create a new empty factor graph.
44    pub fn new() -> Self {
45        Self {
46            variables: HashMap::new(),
47            factors: HashMap::new(),
48            var_to_factors: HashMap::new(),
49            factor_to_vars: HashMap::new(),
50        }
51    }
52
53    /// Add a variable to the graph.
54    pub fn add_variable(&mut self, name: String, domain: String) {
55        let node = VariableNode {
56            name: name.clone(),
57            domain,
58            cardinality: 2, // Default binary
59        };
60        self.variables.insert(name.clone(), node);
61        self.var_to_factors.entry(name).or_default();
62    }
63
64    /// Add a variable with specific cardinality.
65    pub fn add_variable_with_card(&mut self, name: String, domain: String, cardinality: usize) {
66        let node = VariableNode {
67            name: name.clone(),
68            domain,
69            cardinality,
70        };
71        self.variables.insert(name.clone(), node);
72        self.var_to_factors.entry(name).or_default();
73    }
74
75    /// Add a factor to the graph.
76    pub fn add_factor(&mut self, factor: Factor) -> Result<()> {
77        let factor_id = factor.name.clone();
78
79        // Ensure all variables exist
80        for var in &factor.variables {
81            if !self.variables.contains_key(var) {
82                return Err(PgmError::VariableNotFound(var.clone()));
83            }
84        }
85
86        // Update adjacency lists
87        for var in &factor.variables {
88            self.var_to_factors
89                .entry(var.clone())
90                .or_default()
91                .push(factor_id.clone());
92        }
93        self.factor_to_vars
94            .insert(factor_id.clone(), factor.variables.clone());
95
96        self.factors.insert(factor_id, factor);
97        Ok(())
98    }
99
100    /// Add a factor from predicate name and variables.
101    pub fn add_factor_from_predicate(&mut self, name: &str, var_names: &[String]) -> Result<()> {
102        // Create uniform factor
103        let factor = Factor::uniform(name.to_string(), var_names.to_vec(), 2);
104        self.add_factor(factor)
105    }
106
107    /// Get variable node.
108    pub fn get_variable(&self, name: &str) -> Option<&VariableNode> {
109        self.variables.get(name)
110    }
111
112    /// Get factor by ID.
113    pub fn get_factor(&self, id: &str) -> Option<&Factor> {
114        self.factors.get(id)
115    }
116
117    /// Get factor by name.
118    pub fn get_factor_by_name(&self, name: &str) -> Option<&Factor> {
119        // Search for factor with matching name
120        self.factors.values().find(|f| f.name == name)
121    }
122
123    /// Get factors connected to a variable.
124    pub fn get_adjacent_factors(&self, var: &str) -> Option<&Vec<String>> {
125        self.var_to_factors.get(var)
126    }
127
128    /// Get variables connected to a factor.
129    pub fn get_adjacent_variables(&self, factor_id: &str) -> Option<&Vec<String>> {
130        self.factor_to_vars.get(factor_id)
131    }
132
133    /// Get number of variables.
134    pub fn num_variables(&self) -> usize {
135        self.variables.len()
136    }
137
138    /// Get number of factors.
139    pub fn num_factors(&self) -> usize {
140        self.factors.len()
141    }
142
143    /// Check if graph is empty.
144    pub fn is_empty(&self) -> bool {
145        self.variables.is_empty() && self.factors.is_empty()
146    }
147
148    /// Get all variable names.
149    pub fn variable_names(&self) -> impl Iterator<Item = &String> {
150        self.variables.keys()
151    }
152
153    /// Get all factor IDs.
154    pub fn factor_ids(&self) -> impl Iterator<Item = &String> {
155        self.factors.keys()
156    }
157
158    /// Get all variables as an iterator.
159    pub fn variables(&self) -> impl Iterator<Item = (&String, &VariableNode)> {
160        self.variables.iter()
161    }
162
163    /// Get all factors as an iterator.
164    pub fn factors(&self) -> impl Iterator<Item = &Factor> {
165        self.factors.values()
166    }
167
168    /// Get all factors as a vector (for external use).
169    pub fn get_all_factors(&self) -> Vec<&Factor> {
170        self.factors.values().collect()
171    }
172}
173
174impl Default for FactorGraph {
175    fn default() -> Self {
176        Self::new()
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a graph holding the binary variables `x` (domain `D1`) and `y` (domain `D2`).
    fn graph_with_xy() -> FactorGraph {
        let mut graph = FactorGraph::new();
        graph.add_variable("x".to_string(), "D1".to_string());
        graph.add_variable("y".to_string(), "D2".to_string());
        graph
    }

    /// The variable list `["x", "y"]` used by most factor tests.
    fn xy() -> Vec<String> {
        vec!["x".to_string(), "y".to_string()]
    }

    #[test]
    fn test_graph_creation() {
        let graph = FactorGraph::new();
        assert!(graph.is_empty());
        assert_eq!(graph.num_variables(), 0);
        assert_eq!(graph.num_factors(), 0);
    }

    #[test]
    fn test_add_variables() {
        let graph = graph_with_xy();

        assert_eq!(graph.num_variables(), 2);
        assert!(!graph.is_empty());

        let x = graph.get_variable("x").expect("x was added");
        assert_eq!(x.domain, "D1");
        // `add_variable` defaults to binary variables.
        assert_eq!(x.cardinality, 2);
        // A freshly added variable has an empty (but present) adjacency list.
        assert_eq!(graph.get_adjacent_factors("x").map(Vec::len), Some(0));
    }

    #[test]
    fn test_add_variable_with_card() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("z".to_string(), "D3".to_string(), 5);

        assert_eq!(graph.get_variable("z").map(|v| v.cardinality), Some(5));
    }

    #[test]
    fn test_add_factor() {
        let mut graph = graph_with_xy();

        let result = graph.add_factor_from_predicate("P", &xy());
        assert!(result.is_ok());
        assert_eq!(graph.num_factors(), 1);
        assert!(graph.get_factor("P").is_some());
        assert!(graph.get_factor_by_name("P").is_some());
    }

    #[test]
    fn test_add_factor_unknown_variable() {
        let mut graph = graph_with_xy();

        let result = graph.add_factor_from_predicate("Q", &["x".to_string(), "z".to_string()]);
        assert!(matches!(result, Err(PgmError::VariableNotFound(ref v)) if v == "z"));
        // A rejected factor must not leave any trace in the graph.
        assert_eq!(graph.num_factors(), 0);
        assert_eq!(graph.get_adjacent_factors("x").map(Vec::len), Some(0));
    }

    #[test]
    fn test_adjacency() {
        let mut graph = graph_with_xy();
        graph.add_factor_from_predicate("P", &xy()).unwrap();

        let adjacent = graph.get_adjacent_factors("x");
        assert!(adjacent.is_some());
        assert_eq!(adjacent.unwrap().len(), 1);

        assert_eq!(graph.get_adjacent_variables("P").map(Vec::len), Some(2));
        assert!(graph.get_adjacent_factors("missing").is_none());
    }
}