1use scirs2_core::ndarray::ArrayD;
4use serde::{Deserialize, Serialize};
5
6use crate::error::{PgmError, Result};
7
/// A discrete factor over a list of named variables.
///
/// `values` is a dense table with one axis per entry of `variables`, in the same
/// order; the length of axis `i` is the cardinality of `variables[i]`.
/// `Factor::new` checks the axis count, but the fields are public, so code that
/// builds a `Factor` literal directly must keep them consistent itself.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Factor {
    /// Variable names, one per axis of `values`.
    pub variables: Vec<String>,
    /// Factor values, indexed by an assignment to `variables` (one index per axis).
    pub values: ArrayD<f64>,
    /// Label for the factor; the factor operations derive result names from it
    /// (e.g. `"f1*f2"`, `"f1_marg"`).
    pub name: String,
}
20
21impl Factor {
22 pub fn new(name: String, variables: Vec<String>, values: ArrayD<f64>) -> Result<Self> {
24 if values.ndim() != variables.len() {
26 return Err(PgmError::DimensionMismatch {
27 expected: vec![variables.len()],
28 got: vec![values.ndim()],
29 });
30 }
31
32 Ok(Self {
33 name,
34 variables,
35 values,
36 })
37 }
38
39 pub fn uniform(name: String, variables: Vec<String>, card: usize) -> Self {
41 let shape = vec![card; variables.len()];
42 let values = ArrayD::from_elem(shape, 1.0 / (card.pow(variables.len() as u32) as f64));
43 Self {
44 name,
45 variables,
46 values,
47 }
48 }
49
50 pub fn normalize(&mut self) {
52 let sum: f64 = self.values.iter().sum();
53 if sum > 0.0 {
54 self.values /= sum;
55 }
56 }
57
58 pub fn get_cardinality(&self, var: &str) -> Option<usize> {
60 self.variables
61 .iter()
62 .position(|v| v == var)
63 .map(|idx| self.values.shape()[idx])
64 }
65}
66
/// Identifies one of the basic factor operations.
///
/// Each variant names an operation that is provided as a method on [`Factor`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FactorOp {
    /// Pointwise product of two factors (cf. [`Factor::product`]).
    Product,
    /// Summing variables out of a factor (cf. [`Factor::marginalize_out`]).
    Marginalize,
    /// Pointwise division of two factors (cf. [`Factor::divide`]).
    Divide,
}
76
77impl Factor {
78 pub fn product(&self, other: &Factor) -> Result<Factor> {
82 let mut all_vars = self.variables.clone();
84 for v in &other.variables {
85 if !all_vars.contains(v) {
86 all_vars.push(v.clone());
87 }
88 }
89
90 let mut shape = Vec::new();
92 let mut self_mapping = Vec::new(); let mut other_mapping = Vec::new(); for var in &all_vars {
96 let self_idx_opt = self.variables.iter().position(|v| v == var);
98 let other_idx_opt = other.variables.iter().position(|v| v == var);
99
100 let cardinality = if let Some(self_idx) = self_idx_opt {
101 self_mapping.push(Some(self_idx));
102 self.values.shape()[self_idx]
103 } else if let Some(other_idx) = other_idx_opt {
104 self_mapping.push(None);
105 other.values.shape()[other_idx]
106 } else {
107 unreachable!("Variable must be in at least one factor");
108 };
109
110 if let Some(other_idx) = other_idx_opt {
111 other_mapping.push(Some(other_idx));
112 } else {
113 other_mapping.push(None);
114 }
115
116 shape.push(cardinality);
117 }
118
119 let mut result_values = ArrayD::zeros(shape.clone());
121 let total_size: usize = shape.iter().product();
122
123 for linear_idx in 0..total_size {
124 let mut assignment = Vec::new();
126 let mut temp_idx = linear_idx;
127 for &dim in shape.iter().rev() {
128 assignment.push(temp_idx % dim);
129 temp_idx /= dim;
130 }
131 assignment.reverse();
132
133 let self_idx: Vec<usize> = self_mapping
135 .iter()
136 .enumerate()
137 .filter_map(|(i, &opt)| opt.map(|_| assignment[i]))
138 .collect();
139
140 let other_idx: Vec<usize> = other_mapping
141 .iter()
142 .enumerate()
143 .filter_map(|(i, &opt)| opt.map(|_| assignment[i]))
144 .collect();
145
146 let self_val = if self_idx.len() == self.variables.len() {
148 self.values[self_idx.as_slice()]
149 } else {
150 1.0
151 };
152
153 let other_val = if other_idx.len() == other.variables.len() {
154 other.values[other_idx.as_slice()]
155 } else {
156 1.0
157 };
158
159 result_values[assignment.as_slice()] = self_val * other_val;
160 }
161
162 Ok(Factor {
163 name: format!("{}*{}", self.name, other.name),
164 variables: all_vars,
165 values: result_values,
166 })
167 }
168
169 pub fn marginalize_out(&self, var: &str) -> Result<Factor> {
173 use scirs2_core::ndarray::Axis;
174
175 let var_idx = self
177 .variables
178 .iter()
179 .position(|v| v == var)
180 .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
181
182 let new_values = self.values.sum_axis(Axis(var_idx));
184
185 let new_vars: Vec<String> = self
187 .variables
188 .iter()
189 .filter(|v| *v != var)
190 .cloned()
191 .collect();
192
193 Ok(Factor {
194 name: format!("{}_marg", self.name),
195 variables: new_vars,
196 values: new_values,
197 })
198 }
199
200 pub fn marginalize_out_vars(&self, vars: &[String]) -> Result<Factor> {
202 let mut result = self.clone();
203 for var in vars {
204 result = result.marginalize_out(var)?;
205 }
206 Ok(result)
207 }
208
209 pub fn marginalize_out_all_except(&self, keep_vars: &[String]) -> Result<Factor> {
213 let vars_to_remove: Vec<String> = self
214 .variables
215 .iter()
216 .filter(|v| !keep_vars.contains(v))
217 .cloned()
218 .collect();
219
220 self.marginalize_out_vars(&vars_to_remove)
221 }
222
223 pub fn maximize_out(&self, var: &str) -> Result<Factor> {
227 use scirs2_core::ndarray::Axis;
228
229 let var_idx = self
231 .variables
232 .iter()
233 .position(|v| v == var)
234 .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
235
236 let new_values = self.values.map_axis(Axis(var_idx), |view| {
238 view.iter().fold(f64::NEG_INFINITY, |acc, &x| acc.max(x))
239 });
240
241 let new_vars: Vec<String> = self
243 .variables
244 .iter()
245 .filter(|v| *v != var)
246 .cloned()
247 .collect();
248
249 Ok(Factor {
250 name: format!("{}_max", self.name),
251 variables: new_vars,
252 values: new_values,
253 })
254 }
255
256 pub fn maximize_out_vars(&self, vars: &[String]) -> Result<Factor> {
258 let mut result = self.clone();
259 for var in vars {
260 result = result.maximize_out(var)?;
261 }
262 Ok(result)
263 }
264
265 pub fn divide(&self, other: &Factor) -> Result<Factor> {
269 if self.variables != other.variables {
271 return Err(PgmError::InvalidDistribution(
272 "Cannot divide factors with different variables".to_string(),
273 ));
274 }
275
276 let result_values = &self.values
278 / &other
279 .values
280 .mapv(|x| if x.abs() < 1e-10 { 1e-10 } else { x });
281
282 Ok(Factor {
283 name: format!("{}/{}", self.name, other.name),
284 variables: self.variables.clone(),
285 values: result_values,
286 })
287 }
288
289 pub fn reduce(&self, var: &str, value: usize) -> Result<Factor> {
291 use scirs2_core::ndarray::Axis;
292
293 let var_idx = self
294 .variables
295 .iter()
296 .position(|v| v == var)
297 .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
298
299 if value >= self.values.shape()[var_idx] {
301 return Err(PgmError::InvalidDistribution(format!(
302 "Value {} out of bounds for variable {} with cardinality {}",
303 value,
304 var,
305 self.values.shape()[var_idx]
306 )));
307 }
308
309 let new_values = self.values.index_axis(Axis(var_idx), value).to_owned();
311
312 let new_vars: Vec<String> = self
314 .variables
315 .iter()
316 .filter(|v| *v != var)
317 .cloned()
318 .collect();
319
320 Ok(Factor {
321 name: format!("{}_reduced", self.name),
322 variables: new_vars,
323 values: new_values,
324 })
325 }
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array, ArrayD};

    /// Absolute tolerance for floating-point comparisons.
    const EPS: f64 = 1e-10;

    /// Builds a dynamic-dimensional table from a shape and row-major data.
    fn table(shape: Vec<usize>, data: Vec<f64>) -> ArrayD<f64> {
        Array::from_shape_vec(shape, data).unwrap().into_dyn()
    }

    /// Converts string slices into an owned variable list.
    fn vars(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Builds a factor, panicking if `shape` and `names` disagree.
    fn make_factor(name: &str, names: &[&str], shape: Vec<usize>, data: Vec<f64>) -> Factor {
        Factor::new(name.to_string(), vars(names), table(shape, data)).unwrap()
    }

    /// Asserts that two floats agree to within `EPS`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_factor_creation() {
        let factor = make_factor("f1", &["x", "y"], vec![2, 2], vec![0.1, 0.2, 0.3, 0.4]);

        assert_eq!(factor.variables.len(), 2);
        assert_eq!(factor.values.ndim(), 2);
        assert_eq!(factor.get_cardinality("x"), Some(2));
        assert_eq!(factor.get_cardinality("z"), None);
    }

    #[test]
    fn test_factor_creation_rejects_dimension_mismatch() {
        let result = Factor::new(
            "f1".to_string(),
            vars(&["x"]),
            table(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4]),
        );
        assert!(matches!(result, Err(PgmError::DimensionMismatch { .. })));
    }

    #[test]
    fn test_factor_normalize() {
        let mut factor = make_factor("f1", &["x", "y"], vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]);

        factor.normalize();
        let sum: f64 = factor.values.iter().sum();
        assert_close(sum, 1.0);
        assert_close(factor.values[[0, 0]], 0.1);
        assert_close(factor.values[[1, 1]], 0.4);
    }

    #[test]
    fn test_uniform_factor() {
        let factor = Factor::uniform("f1".to_string(), vars(&["x"]), 3);
        assert_eq!(factor.values.len(), 3);
        let sum: f64 = factor.values.iter().sum();
        assert_close(sum, 1.0);

        let pair = Factor::uniform("f2".to_string(), vars(&["x", "y"]), 3);
        assert_eq!(pair.values.shape(), &[3, 3]);
        assert!(pair.values.iter().all(|&v| (v - 1.0 / 9.0).abs() < EPS));
    }

    #[test]
    fn test_factor_product() {
        let f1 = make_factor("f1", &["x"], vec![2], vec![0.6, 0.4]);
        let f2 = make_factor("f2", &["y"], vec![2], vec![0.7, 0.3]);

        let product = f1.product(&f2).unwrap();
        assert_eq!(product.variables, vars(&["x", "y"]));
        assert_eq!(product.values.shape(), &[2, 2]);
        assert_eq!(product.name, "f1*f2");

        // Disjoint scopes produce the outer product.
        assert_close(product.values[[0, 0]], 0.6 * 0.7);
        assert_close(product.values[[0, 1]], 0.6 * 0.3);
        assert_close(product.values[[1, 0]], 0.4 * 0.7);
        assert_close(product.values[[1, 1]], 0.4 * 0.3);

        let expected = 0.6 * 0.7 + 0.6 * 0.3 + 0.4 * 0.7 + 0.4 * 0.3;
        let actual: f64 = product.values.iter().sum();
        assert_close(actual, expected);
    }

    #[test]
    fn test_factor_product_with_shared_vars() {
        let f1 = make_factor("f1", &["x", "y"], vec![2, 2], vec![0.1, 0.2, 0.3, 0.4]);
        let f2 = make_factor("f2", &["y", "z"], vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]);

        let product = f1.product(&f2).unwrap();
        assert_eq!(product.variables, vars(&["x", "y", "z"]));
        assert_eq!(product.values.shape(), &[2, 2, 2]);

        for x in 0..2 {
            for y in 0..2 {
                for z in 0..2 {
                    let expected = f1.values[[x, y]] * f2.values[[y, z]];
                    assert_close(product.values[[x, y, z]], expected);
                }
            }
        }
    }

    #[test]
    fn test_factor_marginalize() {
        let factor = make_factor("f1", &["x", "y"], vec![2, 2], vec![0.1, 0.2, 0.3, 0.4]);

        let marginal = factor.marginalize_out("y").unwrap();
        assert_eq!(marginal.variables, vars(&["x"]));
        assert_eq!(marginal.values.shape(), &[2]);
        assert_close(marginal.values[[0]], 0.3);
        assert_close(marginal.values[[1]], 0.7);
    }

    #[test]
    fn test_factor_marginalize_unknown_variable() {
        let factor = make_factor("f1", &["x", "y"], vec![2, 2], vec![0.1, 0.2, 0.3, 0.4]);
        assert!(matches!(
            factor.marginalize_out("z"),
            Err(PgmError::VariableNotFound(_))
        ));
    }

    #[test]
    fn test_factor_marginalize_all_except() {
        let factor = make_factor("f1", &["x", "y"], vec![2, 2], vec![0.1, 0.2, 0.3, 0.4]);

        let marginal = factor.marginalize_out_all_except(&vars(&["x"])).unwrap();
        assert_eq!(marginal.variables, vars(&["x"]));
        assert_close(marginal.values[[0]], 0.3);
        assert_close(marginal.values[[1]], 0.7);
    }

    #[test]
    fn test_factor_divide() {
        let f1 = make_factor("f1", &["x"], vec![2], vec![0.6, 0.4]);
        let f2 = make_factor("f2", &["x"], vec![2], vec![0.3, 0.2]);

        let result = f1.divide(&f2).unwrap();
        assert_eq!(result.variables, vars(&["x"]));
        assert_close(result.values[[0]], 2.0);
        assert_close(result.values[[1]], 2.0);
    }

    #[test]
    fn test_factor_divide_rejects_different_variables() {
        let f1 = make_factor("f1", &["x"], vec![2], vec![0.6, 0.4]);
        let f2 = make_factor("f2", &["y"], vec![2], vec![0.3, 0.2]);

        assert!(matches!(
            f1.divide(&f2),
            Err(PgmError::InvalidDistribution(_))
        ));
    }

    #[test]
    fn test_factor_reduce() {
        let factor = make_factor("f1", &["x", "y"], vec![2, 2], vec![0.1, 0.2, 0.3, 0.4]);

        let reduced = factor.reduce("y", 1).unwrap();
        assert_eq!(reduced.variables, vars(&["x"]));
        assert_close(reduced.values[[0]], 0.2);
        assert_close(reduced.values[[1]], 0.4);
    }

    #[test]
    fn test_factor_reduce_out_of_bounds() {
        let factor = make_factor("f1", &["x", "y"], vec![2, 2], vec![0.1, 0.2, 0.3, 0.4]);
        assert!(matches!(
            factor.reduce("y", 2),
            Err(PgmError::InvalidDistribution(_))
        ));
    }

    #[test]
    fn test_factor_maximize() {
        let factor = make_factor("f1", &["x", "y"], vec![2, 2], vec![0.1, 0.2, 0.3, 0.4]);

        let maximized = factor.maximize_out("y").unwrap();
        assert_eq!(maximized.variables, vars(&["x"]));
        assert_eq!(maximized.values.shape(), &[2]);
        assert_close(maximized.values[[0]], 0.2);
        assert_close(maximized.values[[1]], 0.4);
    }

    #[test]
    fn test_factor_maximize_multiple() {
        let factor = make_factor(
            "f1",
            &["x", "y", "z"],
            vec![2, 2, 2],
            vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
        );

        let maximized = factor.maximize_out_vars(&vars(&["y", "z"])).unwrap();
        assert_eq!(maximized.variables, vars(&["x"]));
        assert_eq!(maximized.values.shape(), &[2]);
        assert_close(maximized.values[[0]], 0.4);
        assert_close(maximized.values[[1]], 0.8);
    }
}