1use scirs2_core::ndarray::{ArrayD, IxDyn};
14use std::collections::{HashMap, HashSet};
15
16use crate::error::{PgmError, Result};
17use crate::{Factor, FactorGraph, VariableElimination};
18
/// The role a node plays in an [`InfluenceDiagram`].
///
/// The derived `Debug` representation (`"Chance"`, `"Decision"`, `"Utility"`) is used as the
/// variable-type label when the diagram is converted to a factor graph, so it must stay
/// derived. `Copy` and `Hash` are derived because this is a small fieldless enum that callers
/// may want to pass by value or use as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeType {
    /// A random variable described by a conditional probability table.
    Chance,
    /// A variable whose state is chosen by the decision maker.
    Decision,
    /// A node that scores outcomes; utility nodes are created with cardinality 1.
    Utility,
}
29
30#[derive(Debug, Clone)]
32pub struct InfluenceNode {
33 pub name: String,
35 pub node_type: NodeType,
37 pub cardinality: usize,
39 pub parents: Vec<String>,
41}
42
43#[derive(Debug, Clone)]
63pub struct InfluenceDiagram {
64 nodes: HashMap<String, InfluenceNode>,
66 cpts: HashMap<String, ArrayD<f64>>,
68 utilities: HashMap<String, ArrayD<f64>>,
70 decision_order: Vec<String>,
72}
73
74impl Default for InfluenceDiagram {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80impl InfluenceDiagram {
81 pub fn new() -> Self {
83 Self {
84 nodes: HashMap::new(),
85 cpts: HashMap::new(),
86 utilities: HashMap::new(),
87 decision_order: Vec::new(),
88 }
89 }
90
91 pub fn add_chance_node(
93 &mut self,
94 name: String,
95 cardinality: usize,
96 parents: Vec<String>,
97 ) -> &mut Self {
98 self.nodes.insert(
99 name.clone(),
100 InfluenceNode {
101 name,
102 node_type: NodeType::Chance,
103 cardinality,
104 parents,
105 },
106 );
107 self
108 }
109
110 pub fn add_decision_node(
112 &mut self,
113 name: String,
114 cardinality: usize,
115 parents: Vec<String>,
116 ) -> &mut Self {
117 let node_name = name.clone();
118 self.nodes.insert(
119 name.clone(),
120 InfluenceNode {
121 name,
122 node_type: NodeType::Decision,
123 cardinality,
124 parents,
125 },
126 );
127 self.decision_order.push(node_name);
128 self
129 }
130
131 pub fn add_utility_node(&mut self, name: String, parents: Vec<String>) -> &mut Self {
133 self.nodes.insert(
134 name.clone(),
135 InfluenceNode {
136 name,
137 node_type: NodeType::Utility,
138 cardinality: 1, parents,
140 },
141 );
142 self
143 }
144
145 pub fn set_cpt(&mut self, node: &str, cpt: ArrayD<f64>) -> Result<&mut Self> {
147 if let Some(n) = self.nodes.get(node) {
148 if n.node_type != NodeType::Chance {
149 return Err(PgmError::InvalidDistribution(format!(
150 "Node {} is not a chance node",
151 node
152 )));
153 }
154 } else {
155 return Err(PgmError::VariableNotFound(node.to_string()));
156 }
157 self.cpts.insert(node.to_string(), cpt);
158 Ok(self)
159 }
160
161 pub fn set_utility(&mut self, node: &str, utility: ArrayD<f64>) -> Result<&mut Self> {
163 if let Some(n) = self.nodes.get(node) {
164 if n.node_type != NodeType::Utility {
165 return Err(PgmError::InvalidDistribution(format!(
166 "Node {} is not a utility node",
167 node
168 )));
169 }
170 } else {
171 return Err(PgmError::VariableNotFound(node.to_string()));
172 }
173 self.utilities.insert(node.to_string(), utility);
174 Ok(self)
175 }
176
177 pub fn set_decision_order(&mut self, order: Vec<String>) -> &mut Self {
179 self.decision_order = order;
180 self
181 }
182
183 pub fn chance_nodes(&self) -> Vec<&InfluenceNode> {
185 self.nodes
186 .values()
187 .filter(|n| n.node_type == NodeType::Chance)
188 .collect()
189 }
190
191 pub fn decision_nodes(&self) -> Vec<&InfluenceNode> {
193 self.nodes
194 .values()
195 .filter(|n| n.node_type == NodeType::Decision)
196 .collect()
197 }
198
199 pub fn utility_nodes(&self) -> Vec<&InfluenceNode> {
201 self.nodes
202 .values()
203 .filter(|n| n.node_type == NodeType::Utility)
204 .collect()
205 }
206
207 pub fn get_node(&self, name: &str) -> Option<&InfluenceNode> {
209 self.nodes.get(name)
210 }
211
212 pub fn to_factor_graph(&self) -> Result<FactorGraph> {
216 let mut graph = FactorGraph::new();
217
218 for (name, node) in &self.nodes {
220 if node.node_type != NodeType::Utility {
221 graph.add_variable_with_card(
222 name.clone(),
223 format!("{:?}", node.node_type),
224 node.cardinality,
225 );
226 }
227 }
228
229 for (name, cpt) in &self.cpts {
231 if let Some(node) = self.nodes.get(name) {
232 let mut vars = node.parents.clone();
233 vars.push(name.clone());
234
235 let factor = Factor::new(format!("P({})", name), vars, cpt.clone())?;
236 graph.add_factor(factor)?;
237 }
238 }
239
240 for (name, node) in &self.nodes {
242 if node.node_type == NodeType::Decision {
243 let uniform =
244 ArrayD::from_elem(IxDyn(&[node.cardinality]), 1.0 / node.cardinality as f64);
245 let factor = Factor::new(format!("U({})", name), vec![name.clone()], uniform)?;
246 graph.add_factor(factor)?;
247 }
248 }
249
250 Ok(graph)
251 }
252
253 pub fn expected_utility(&self, policy: &HashMap<String, usize>) -> Result<f64> {
257 let graph = self.to_factor_graph()?;
259
260 let ve = VariableElimination::default();
262
263 let mut total_utility = 0.0;
265
266 for (utility_name, utility_table) in &self.utilities {
267 if let Some(node) = self.nodes.get(utility_name) {
268 let parent_cardinalities: Vec<usize> = node
270 .parents
271 .iter()
272 .filter_map(|p| self.nodes.get(p).map(|n| n.cardinality))
273 .collect();
274
275 if parent_cardinalities.is_empty() {
276 total_utility += utility_table.iter().next().copied().unwrap_or(0.0);
278 continue;
279 }
280
281 let total_size: usize = parent_cardinalities.iter().product();
283
284 for flat_idx in 0..total_size {
285 let mut indices = vec![0; parent_cardinalities.len()];
287 let mut remaining = flat_idx;
288 for i in (0..parent_cardinalities.len()).rev() {
289 indices[i] = remaining % parent_cardinalities[i];
290 remaining /= parent_cardinalities[i];
291 }
292
293 let utility_val = utility_table[indices.as_slice()];
295
296 let mut prob = 1.0;
298 for (i, parent) in node.parents.iter().enumerate() {
299 if let Some(parent_node) = self.nodes.get(parent) {
300 match parent_node.node_type {
301 NodeType::Decision => {
302 if let Some(&policy_val) = policy.get(parent) {
304 if policy_val != indices[i] {
305 prob = 0.0;
306 break;
307 }
308 }
309 }
310 NodeType::Chance => {
311 if let Ok(marginal) = ve.marginalize(&graph, parent) {
313 if indices[i] < marginal.len() {
314 prob *= marginal[indices[i]];
315 }
316 }
317 }
318 NodeType::Utility => {}
319 }
320 }
321 }
322
323 total_utility += prob * utility_val;
324 }
325 }
326 }
327
328 Ok(total_utility)
329 }
330
331 pub fn optimal_policy(&self) -> Result<(HashMap<String, usize>, f64)> {
335 let decisions: Vec<_> = self.decision_nodes();
336
337 if decisions.is_empty() {
338 return Ok((HashMap::new(), self.expected_utility(&HashMap::new())?));
339 }
340
341 let mut best_policy = HashMap::new();
343 let mut best_utility = f64::NEG_INFINITY;
344
345 let cardinalities: Vec<usize> = decisions.iter().map(|d| d.cardinality).collect();
346 let total_policies: usize = cardinalities.iter().product();
347
348 for policy_idx in 0..total_policies {
349 let mut policy = HashMap::new();
351 let mut remaining = policy_idx;
352
353 for (i, decision) in decisions.iter().enumerate() {
354 let value = remaining % cardinalities[i];
355 remaining /= cardinalities[i];
356 policy.insert(decision.name.clone(), value);
357 }
358
359 let utility = self.expected_utility(&policy)?;
361
362 if utility > best_utility {
363 best_utility = utility;
364 best_policy = policy;
365 }
366 }
367
368 Ok((best_policy, best_utility))
369 }
370
371 pub fn value_of_perfect_information(&self, node: &str) -> Result<f64> {
376 if let Some(n) = self.nodes.get(node) {
378 if n.node_type != NodeType::Chance {
379 return Err(PgmError::InvalidDistribution(format!(
380 "Node {} is not a chance node",
381 node
382 )));
383 }
384 } else {
385 return Err(PgmError::VariableNotFound(node.to_string()));
386 }
387
388 let (_, base_utility) = self.optimal_policy()?;
390
391 let node_card = self.nodes.get(node).unwrap().cardinality;
394
395 let graph = self.to_factor_graph()?;
397 let ve = VariableElimination::default();
398 let marginal = ve.marginalize(&graph, node)?;
399
400 let mut expected_with_info = 0.0;
401
402 for value in 0..node_card {
403 let prob = if value < marginal.len() {
406 marginal[value]
407 } else {
408 0.0
409 };
410
411 expected_with_info += prob * base_utility;
413 }
414
415 Ok((expected_with_info - base_utility).max(0.0))
416 }
417
418 pub fn information_parents(&self, decision: &str) -> Vec<String> {
422 if let Some(node) = self.nodes.get(decision) {
423 if node.node_type == NodeType::Decision {
424 return node.parents.clone();
425 }
426 }
427 Vec::new()
428 }
429
430 pub fn is_well_formed(&self) -> bool {
437 for (name, node) in &self.nodes {
439 if node.node_type == NodeType::Utility {
440 for other in self.nodes.values() {
441 if other.parents.contains(name) {
442 return false;
443 }
444 }
445 }
446 }
447
448 let mut visited = HashSet::new();
450 let mut rec_stack = HashSet::new();
451
452 for name in self.nodes.keys() {
453 if !visited.contains(name) && self.has_cycle(name, &mut visited, &mut rec_stack) {
454 return false;
455 }
456 }
457
458 true
459 }
460
461 fn has_cycle(
463 &self,
464 node: &str,
465 visited: &mut HashSet<String>,
466 rec_stack: &mut HashSet<String>,
467 ) -> bool {
468 visited.insert(node.to_string());
469 rec_stack.insert(node.to_string());
470
471 if let Some(n) = self.nodes.get(node) {
472 for parent in &n.parents {
473 if !visited.contains(parent) {
474 if self.has_cycle(parent, visited, rec_stack) {
475 return true;
476 }
477 } else if rec_stack.contains(parent) {
478 return true;
479 }
480 }
481 }
482
483 rec_stack.remove(node);
484 false
485 }
486
487 pub fn num_nodes(&self) -> usize {
489 self.nodes.len()
490 }
491
492 pub fn num_decisions(&self) -> usize {
494 self.decision_nodes().len()
495 }
496
497 pub fn num_utilities(&self) -> usize {
499 self.utility_nodes().len()
500 }
501}
502
503pub struct InfluenceDiagramBuilder {
505 diagram: InfluenceDiagram,
506}
507
508impl Default for InfluenceDiagramBuilder {
509 fn default() -> Self {
510 Self::new()
511 }
512}
513
514impl InfluenceDiagramBuilder {
515 pub fn new() -> Self {
517 Self {
518 diagram: InfluenceDiagram::new(),
519 }
520 }
521
522 pub fn chance_node(mut self, name: String, cardinality: usize, parents: Vec<String>) -> Self {
524 self.diagram.add_chance_node(name, cardinality, parents);
525 self
526 }
527
528 pub fn decision_node(mut self, name: String, cardinality: usize, parents: Vec<String>) -> Self {
530 self.diagram.add_decision_node(name, cardinality, parents);
531 self
532 }
533
534 pub fn utility_node(mut self, name: String, parents: Vec<String>) -> Self {
536 self.diagram.add_utility_node(name, parents);
537 self
538 }
539
540 pub fn cpt(mut self, node: &str, cpt: ArrayD<f64>) -> Result<Self> {
542 self.diagram.set_cpt(node, cpt)?;
543 Ok(self)
544 }
545
546 pub fn utility(mut self, node: &str, utility: ArrayD<f64>) -> Result<Self> {
548 self.diagram.set_utility(node, utility)?;
549 Ok(self)
550 }
551
552 pub fn build(self) -> InfluenceDiagram {
554 self.diagram
555 }
556}
557
/// Additive multi-attribute utility: a list of named attributes with linear weights.
#[derive(Clone, Debug)]
pub struct MultiAttributeUtility {
    /// `(attribute name, weight)` pairs in insertion order.
    utilities: Vec<(String, f64)>,
}
564
565impl Default for MultiAttributeUtility {
566 fn default() -> Self {
567 Self::new()
568 }
569}
570
571impl MultiAttributeUtility {
572 pub fn new() -> Self {
574 Self {
575 utilities: Vec::new(),
576 }
577 }
578
579 pub fn add_utility(&mut self, name: String, weight: f64) -> &mut Self {
581 self.utilities.push((name, weight));
582 self
583 }
584
585 pub fn combine(&self, values: &HashMap<String, f64>) -> f64 {
587 let mut total = 0.0;
588
589 for (name, weight) in &self.utilities {
590 if let Some(&value) = values.get(name) {
591 total += weight * value;
592 }
593 }
594
595 total
596 }
597
598 pub fn weights(&self) -> HashMap<String, f64> {
600 self.utilities.iter().cloned().collect()
601 }
602
603 pub fn normalize_weights(&mut self) {
605 let total: f64 = self.utilities.iter().map(|(_, w)| w).sum();
606 if total > 0.0 {
607 for (_, w) in &mut self.utilities {
608 *w /= total;
609 }
610 }
611 }
612}
613
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an owned node/attribute name.
    fn owned(name: &str) -> String {
        name.to_string()
    }

    /// Builds an `ArrayD` of the given shape from row-major `values`.
    fn table(shape: &[usize], values: Vec<f64>) -> ArrayD<f64> {
        ArrayD::from_shape_vec(IxDyn(shape), values).expect("shape and value count must agree")
    }

    #[test]
    fn test_influence_diagram_creation() {
        let mut diagram = InfluenceDiagram::new();
        diagram
            .add_chance_node(owned("weather"), 2, Vec::new())
            .add_decision_node(owned("umbrella"), 2, vec![owned("weather")])
            .add_utility_node(owned("comfort"), vec![owned("weather"), owned("umbrella")]);

        assert_eq!(diagram.num_nodes(), 3);
        assert_eq!(diagram.num_decisions(), 1);
        assert_eq!(diagram.num_utilities(), 1);
    }

    #[test]
    fn test_node_types() {
        let mut diagram = InfluenceDiagram::new();
        diagram
            .add_chance_node(owned("c"), 2, Vec::new())
            .add_decision_node(owned("d"), 2, Vec::new())
            .add_utility_node(owned("u"), vec![owned("c"), owned("d")]);

        let counts = [
            diagram.chance_nodes().len(),
            diagram.decision_nodes().len(),
            diagram.utility_nodes().len(),
        ];
        assert_eq!(counts, [1, 1, 1]);
    }

    #[test]
    fn test_set_cpt() {
        let mut diagram = InfluenceDiagram::new();
        diagram.add_chance_node(owned("x"), 2, Vec::new());

        assert!(diagram.set_cpt("x", table(&[2], vec![0.3, 0.7])).is_ok());
    }

    #[test]
    fn test_set_cpt_invalid_node() {
        let mut diagram = InfluenceDiagram::new();
        diagram.add_decision_node(owned("d"), 2, Vec::new());

        assert!(diagram.set_cpt("d", table(&[2], vec![0.3, 0.7])).is_err());
    }

    #[test]
    fn test_set_utility() {
        let mut diagram = InfluenceDiagram::new();
        diagram
            .add_chance_node(owned("x"), 2, Vec::new())
            .add_utility_node(owned("u"), vec![owned("x")]);

        assert!(diagram.set_utility("u", table(&[2], vec![10.0, 20.0])).is_ok());
    }

    #[test]
    fn test_to_factor_graph() {
        let mut diagram = InfluenceDiagram::new();
        diagram
            .add_chance_node(owned("x"), 2, Vec::new())
            .add_decision_node(owned("d"), 2, Vec::new());
        diagram
            .set_cpt("x", table(&[2], vec![0.5, 0.5]))
            .expect("x is a chance node");

        let graph = diagram.to_factor_graph().expect("factor graph should build");
        assert_eq!(graph.num_variables(), 2);
    }

    #[test]
    fn test_well_formed() {
        let mut diagram = InfluenceDiagram::new();
        diagram
            .add_chance_node(owned("x"), 2, Vec::new())
            .add_decision_node(owned("d"), 2, vec![owned("x")])
            .add_utility_node(owned("u"), vec![owned("d")]);

        assert!(diagram.is_well_formed());
    }

    #[test]
    fn test_information_parents() {
        let mut diagram = InfluenceDiagram::new();
        diagram
            .add_chance_node(owned("x"), 2, Vec::new())
            .add_decision_node(owned("d"), 2, vec![owned("x")]);

        assert_eq!(diagram.information_parents("d"), vec![owned("x")]);
    }

    #[test]
    fn test_builder() {
        let diagram = InfluenceDiagramBuilder::new()
            .chance_node(owned("x"), 2, Vec::new())
            .decision_node(owned("d"), 2, vec![owned("x")])
            .utility_node(owned("u"), vec![owned("x"), owned("d")])
            .build();

        assert_eq!(diagram.num_nodes(), 3);
    }

    #[test]
    fn test_multi_attribute_utility() {
        let mut maut = MultiAttributeUtility::new();
        maut.add_utility(owned("cost"), 0.4)
            .add_utility(owned("quality"), 0.6);

        let values: HashMap<String, f64> = vec![(owned("cost"), 10.0), (owned("quality"), 20.0)]
            .into_iter()
            .collect();

        // 0.4 * 10 + 0.6 * 20
        assert!((maut.combine(&values) - 16.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_weights() {
        let mut maut = MultiAttributeUtility::new();
        maut.add_utility(owned("a"), 2.0).add_utility(owned("b"), 3.0);

        maut.normalize_weights();

        let total: f64 = maut.weights().values().sum();
        assert!((total - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_expected_utility_simple() {
        let mut diagram = InfluenceDiagram::new();
        diagram
            .add_decision_node(owned("d"), 2, Vec::new())
            .add_utility_node(owned("u"), vec![owned("d")]);
        diagram
            .set_utility("u", table(&[2], vec![10.0, 20.0]))
            .expect("u is a utility node");

        let mut policy = HashMap::new();
        policy.insert(owned("d"), 1);

        let eu = diagram
            .expected_utility(&policy)
            .expect("expected utility should compute");
        assert!((eu - 20.0).abs() < 1e-6);
    }

    #[test]
    fn test_optimal_policy_simple() {
        let mut diagram = InfluenceDiagram::new();
        diagram
            .add_decision_node(owned("d"), 2, Vec::new())
            .add_utility_node(owned("u"), vec![owned("d")]);
        diagram
            .set_utility("u", table(&[2], vec![10.0, 20.0]))
            .expect("u is a utility node");

        let (policy, eu) = diagram
            .optimal_policy()
            .expect("optimal policy should compute");
        assert_eq!(policy.get("d"), Some(&1));
        assert!((eu - 20.0).abs() < 1e-6);
    }
}