tensorlogic_quantrs_hooks/
graph.rs1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use crate::error::{PgmError, Result};
7use crate::factor::Factor;
8
9#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
11pub struct VariableNode {
12 pub name: String,
14 pub domain: String,
16 pub cardinality: usize,
18}
19
20#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
22pub struct FactorNode {
23 pub id: String,
25 pub variables: Vec<String>,
27}
28
/// A factor graph over named variables and named factors, with the
/// variable/factor adjacency indexed in both directions.
#[derive(Clone, Debug)]
pub struct FactorGraph {
    /// Variable nodes keyed by variable name.
    variables: HashMap<String, VariableNode>,
    /// Factors keyed by their `Factor::name`.
    factors: HashMap<String, Factor>,
    /// For each variable name, the ids of the factors whose scope contains it.
    /// Every added variable has an entry, possibly empty.
    var_to_factors: HashMap<String, Vec<String>>,
    /// For each factor id, the names of the variables in its scope.
    factor_to_vars: HashMap<String, Vec<String>>,
}
41
42impl FactorGraph {
43 pub fn new() -> Self {
45 Self {
46 variables: HashMap::new(),
47 factors: HashMap::new(),
48 var_to_factors: HashMap::new(),
49 factor_to_vars: HashMap::new(),
50 }
51 }
52
53 pub fn add_variable(&mut self, name: String, domain: String) {
55 let node = VariableNode {
56 name: name.clone(),
57 domain,
58 cardinality: 2, };
60 self.variables.insert(name.clone(), node);
61 self.var_to_factors.entry(name).or_default();
62 }
63
64 pub fn add_variable_with_card(&mut self, name: String, domain: String, cardinality: usize) {
66 let node = VariableNode {
67 name: name.clone(),
68 domain,
69 cardinality,
70 };
71 self.variables.insert(name.clone(), node);
72 self.var_to_factors.entry(name).or_default();
73 }
74
75 pub fn add_factor(&mut self, factor: Factor) -> Result<()> {
77 let factor_id = factor.name.clone();
78
79 for var in &factor.variables {
81 if !self.variables.contains_key(var) {
82 return Err(PgmError::VariableNotFound(var.clone()));
83 }
84 }
85
86 for var in &factor.variables {
88 self.var_to_factors
89 .entry(var.clone())
90 .or_default()
91 .push(factor_id.clone());
92 }
93 self.factor_to_vars
94 .insert(factor_id.clone(), factor.variables.clone());
95
96 self.factors.insert(factor_id, factor);
97 Ok(())
98 }
99
100 pub fn add_factor_from_predicate(&mut self, name: &str, var_names: &[String]) -> Result<()> {
102 let factor = Factor::uniform(name.to_string(), var_names.to_vec(), 2);
104 self.add_factor(factor)
105 }
106
107 pub fn get_variable(&self, name: &str) -> Option<&VariableNode> {
109 self.variables.get(name)
110 }
111
112 pub fn get_factor(&self, id: &str) -> Option<&Factor> {
114 self.factors.get(id)
115 }
116
117 pub fn get_factor_by_name(&self, name: &str) -> Option<&Factor> {
119 self.factors.values().find(|f| f.name == name)
121 }
122
123 pub fn get_adjacent_factors(&self, var: &str) -> Option<&Vec<String>> {
125 self.var_to_factors.get(var)
126 }
127
128 pub fn get_adjacent_variables(&self, factor_id: &str) -> Option<&Vec<String>> {
130 self.factor_to_vars.get(factor_id)
131 }
132
133 pub fn num_variables(&self) -> usize {
135 self.variables.len()
136 }
137
138 pub fn num_factors(&self) -> usize {
140 self.factors.len()
141 }
142
143 pub fn is_empty(&self) -> bool {
145 self.variables.is_empty() && self.factors.is_empty()
146 }
147
148 pub fn variable_names(&self) -> impl Iterator<Item = &String> {
150 self.variables.keys()
151 }
152
153 pub fn factor_ids(&self) -> impl Iterator<Item = &String> {
155 self.factors.keys()
156 }
157
158 pub fn variables(&self) -> impl Iterator<Item = (&String, &VariableNode)> {
160 self.variables.iter()
161 }
162
163 pub fn factors(&self) -> impl Iterator<Item = &Factor> {
165 self.factors.values()
166 }
167
168 pub fn get_all_factors(&self) -> Vec<&Factor> {
170 self.factors.values().collect()
171 }
172}
173
174impl Default for FactorGraph {
175 fn default() -> Self {
176 Self::new()
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a graph containing the binary variables `x` (D1) and `y` (D2).
    fn two_variable_graph() -> FactorGraph {
        let mut graph = FactorGraph::new();
        graph.add_variable("x".to_string(), "D1".to_string());
        graph.add_variable("y".to_string(), "D2".to_string());
        graph
    }

    /// The scope `[x, y]` used by the predicate factors below.
    fn xy() -> Vec<String> {
        vec!["x".to_string(), "y".to_string()]
    }

    #[test]
    fn test_graph_creation() {
        let graph = FactorGraph::new();
        assert!(graph.is_empty());
        assert_eq!(graph.num_variables(), 0);
        assert_eq!(graph.num_factors(), 0);
    }

    #[test]
    fn test_add_variables() {
        let graph = two_variable_graph();
        assert_eq!(graph.num_variables(), 2);
        assert!(!graph.is_empty());

        let x = graph.get_variable("x").expect("x was added");
        assert_eq!(x.domain, "D1");
        assert_eq!(x.cardinality, 2);
        assert!(graph.get_variable("z").is_none());
    }

    #[test]
    fn test_add_variable_with_card() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("c".to_string(), "Colors".to_string(), 3);

        assert_eq!(graph.get_variable("c").map(|v| v.cardinality), Some(3));
        // A freshly added variable has an empty adjacency list.
        assert_eq!(graph.get_adjacent_factors("c").map(|f| f.len()), Some(0));
    }

    #[test]
    fn test_add_factor() {
        let mut graph = two_variable_graph();
        let result = graph.add_factor_from_predicate("P", &xy());

        assert!(result.is_ok());
        assert_eq!(graph.num_factors(), 1);
        assert!(graph.get_factor("P").is_some());
        assert!(graph.get_factor_by_name("P").is_some());
        assert!(graph.get_factor_by_name("Q").is_none());
    }

    #[test]
    fn test_add_factor_unknown_variable() {
        let mut graph = two_variable_graph();
        let result = graph.add_factor_from_predicate("P", &["x".to_string(), "z".to_string()]);

        assert!(matches!(result, Err(PgmError::VariableNotFound(ref v)) if v == "z"));
        // A rejected factor must leave the graph untouched.
        assert_eq!(graph.num_factors(), 0);
        assert_eq!(graph.get_adjacent_factors("x").map(|f| f.len()), Some(0));
        assert!(graph.get_adjacent_variables("P").is_none());
    }

    #[test]
    fn test_adjacency() {
        let mut graph = two_variable_graph();
        graph
            .add_factor_from_predicate("P", &xy())
            .expect("x and y are in the graph");

        let adjacent = graph.get_adjacent_factors("x").expect("x is in the graph");
        assert_eq!(adjacent, &vec!["P".to_string()]);
        assert_eq!(graph.get_adjacent_variables("P"), Some(&xy()));
    }
}