tensorlogic_quantrs_hooks/
graph.rs1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use crate::error::{PgmError, Result};
7use crate::factor::Factor;
8
9#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
11pub struct VariableNode {
12 pub name: String,
14 pub domain: String,
16 pub cardinality: usize,
18}
19
20#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
22pub struct FactorNode {
23 pub id: String,
25 pub variables: Vec<String>,
27}
28
29#[derive(Clone, Debug)]
31pub struct FactorGraph {
32 variables: HashMap<String, VariableNode>,
34 factors: HashMap<String, Factor>,
36 var_to_factors: HashMap<String, Vec<String>>,
38 factor_to_vars: HashMap<String, Vec<String>>,
40}
41
42impl FactorGraph {
43 pub fn new() -> Self {
45 Self {
46 variables: HashMap::new(),
47 factors: HashMap::new(),
48 var_to_factors: HashMap::new(),
49 factor_to_vars: HashMap::new(),
50 }
51 }
52
53 pub fn add_variable(&mut self, name: String, domain: String) {
55 let node = VariableNode {
56 name: name.clone(),
57 domain,
58 cardinality: 2, };
60 self.variables.insert(name.clone(), node);
61 self.var_to_factors.entry(name).or_default();
62 }
63
64 pub fn add_variable_with_card(&mut self, name: String, domain: String, cardinality: usize) {
66 let node = VariableNode {
67 name: name.clone(),
68 domain,
69 cardinality,
70 };
71 self.variables.insert(name.clone(), node);
72 self.var_to_factors.entry(name).or_default();
73 }
74
75 pub fn add_factor(&mut self, factor: Factor) -> Result<()> {
77 let factor_id = factor.name.clone();
78
79 for var in &factor.variables {
81 if !self.variables.contains_key(var) {
82 return Err(PgmError::VariableNotFound(var.clone()));
83 }
84 }
85
86 for var in &factor.variables {
88 self.var_to_factors
89 .entry(var.clone())
90 .or_default()
91 .push(factor_id.clone());
92 }
93 self.factor_to_vars
94 .insert(factor_id.clone(), factor.variables.clone());
95
96 self.factors.insert(factor_id, factor);
97 Ok(())
98 }
99
100 pub fn add_factor_from_predicate(&mut self, name: &str, var_names: &[String]) -> Result<()> {
102 let factor = Factor::uniform(name.to_string(), var_names.to_vec(), 2);
104 self.add_factor(factor)
105 }
106
107 pub fn get_variable(&self, name: &str) -> Option<&VariableNode> {
109 self.variables.get(name)
110 }
111
112 pub fn get_factor(&self, id: &str) -> Option<&Factor> {
114 self.factors.get(id)
115 }
116
117 pub fn get_adjacent_factors(&self, var: &str) -> Option<&Vec<String>> {
119 self.var_to_factors.get(var)
120 }
121
122 pub fn get_adjacent_variables(&self, factor_id: &str) -> Option<&Vec<String>> {
124 self.factor_to_vars.get(factor_id)
125 }
126
127 pub fn num_variables(&self) -> usize {
129 self.variables.len()
130 }
131
132 pub fn num_factors(&self) -> usize {
134 self.factors.len()
135 }
136
137 pub fn is_empty(&self) -> bool {
139 self.variables.is_empty() && self.factors.is_empty()
140 }
141
142 pub fn variable_names(&self) -> impl Iterator<Item = &String> {
144 self.variables.keys()
145 }
146
147 pub fn factor_ids(&self) -> impl Iterator<Item = &String> {
149 self.factors.keys()
150 }
151
152 pub fn variables(&self) -> impl Iterator<Item = (&String, &VariableNode)> {
154 self.variables.iter()
155 }
156
157 pub fn factors(&self) -> impl Iterator<Item = &Factor> {
159 self.factors.values()
160 }
161
162 pub fn get_all_factors(&self) -> Vec<&Factor> {
164 self.factors.values().collect()
165 }
166}
167
168impl Default for FactorGraph {
169 fn default() -> Self {
170 Self::new()
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a graph containing two binary variables, `x` and `y`.
    fn two_variable_graph() -> FactorGraph {
        let mut graph = FactorGraph::new();
        graph.add_variable("x".to_string(), "D1".to_string());
        graph.add_variable("y".to_string(), "D2".to_string());
        graph
    }

    #[test]
    fn test_graph_creation() {
        let graph = FactorGraph::new();
        assert!(graph.is_empty());
        assert_eq!(graph.num_variables(), 0);
        assert_eq!(graph.num_factors(), 0);
    }

    #[test]
    fn test_add_variables() {
        let graph = two_variable_graph();
        assert_eq!(graph.num_variables(), 2);
        assert!(!graph.is_empty());

        let x = graph.get_variable("x").expect("x was just added");
        assert_eq!(x.domain, "D1");
        assert_eq!(x.cardinality, 2);
    }

    #[test]
    fn test_add_variable_with_card() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("z".to_string(), "D3".to_string(), 4);
        assert_eq!(graph.get_variable("z").map(|v| v.cardinality), Some(4));
    }

    #[test]
    fn test_add_factor() {
        let mut graph = two_variable_graph();
        let result = graph.add_factor_from_predicate("P", &["x".to_string(), "y".to_string()]);
        assert!(result.is_ok());
        assert_eq!(graph.num_factors(), 1);
    }

    #[test]
    fn test_add_factor_unknown_variable() {
        let mut graph = two_variable_graph();
        let result =
            graph.add_factor_from_predicate("Q", &["x".to_string(), "missing".to_string()]);
        assert!(matches!(
            result,
            Err(PgmError::VariableNotFound(ref v)) if v == "missing"
        ));
        // A rejected factor must leave the graph untouched.
        assert_eq!(graph.num_factors(), 0);
        assert_eq!(graph.get_adjacent_factors("x").map(Vec::len), Some(0));
    }

    #[test]
    fn test_adjacency() {
        let mut graph = two_variable_graph();
        graph
            .add_factor_from_predicate("P", &["x".to_string(), "y".to_string()])
            .expect("both variables exist");

        let adjacent = graph
            .get_adjacent_factors("x")
            .expect("x is a registered variable");
        assert_eq!(adjacent.len(), 1);

        let scope = graph
            .get_adjacent_variables("P")
            .expect("P is a registered factor");
        assert_eq!(scope, &vec!["x".to_string(), "y".to_string()]);
    }
}