1use scirs2_core::ndarray::{ArrayD, IxDyn};
14use std::collections::{HashMap, HashSet};
15
16use crate::error::{PgmError, Result};
17use crate::{Factor, FactorGraph, VariableElimination};
18
/// Kind of node in an influence diagram.
///
/// The enum is fieldless, so it is `Copy` (cheap to pass by value) and
/// `Hash` (usable as a set/map key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeType {
    /// Random variable whose distribution is given by a conditional probability table.
    Chance,
    /// Variable whose state is chosen by the decision maker (via a policy).
    Decision,
    /// Node carrying a utility (payoff) table over its parents.
    Utility,
}
29
30#[derive(Debug, Clone)]
32pub struct InfluenceNode {
33 pub name: String,
35 pub node_type: NodeType,
37 pub cardinality: usize,
39 pub parents: Vec<String>,
41}
42
43#[derive(Debug, Clone)]
63pub struct InfluenceDiagram {
64 nodes: HashMap<String, InfluenceNode>,
66 cpts: HashMap<String, ArrayD<f64>>,
68 utilities: HashMap<String, ArrayD<f64>>,
70 decision_order: Vec<String>,
72}
73
74impl Default for InfluenceDiagram {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80impl InfluenceDiagram {
81 pub fn new() -> Self {
83 Self {
84 nodes: HashMap::new(),
85 cpts: HashMap::new(),
86 utilities: HashMap::new(),
87 decision_order: Vec::new(),
88 }
89 }
90
91 pub fn add_chance_node(
93 &mut self,
94 name: String,
95 cardinality: usize,
96 parents: Vec<String>,
97 ) -> &mut Self {
98 self.nodes.insert(
99 name.clone(),
100 InfluenceNode {
101 name,
102 node_type: NodeType::Chance,
103 cardinality,
104 parents,
105 },
106 );
107 self
108 }
109
110 pub fn add_decision_node(
112 &mut self,
113 name: String,
114 cardinality: usize,
115 parents: Vec<String>,
116 ) -> &mut Self {
117 let node_name = name.clone();
118 self.nodes.insert(
119 name.clone(),
120 InfluenceNode {
121 name,
122 node_type: NodeType::Decision,
123 cardinality,
124 parents,
125 },
126 );
127 self.decision_order.push(node_name);
128 self
129 }
130
131 pub fn add_utility_node(&mut self, name: String, parents: Vec<String>) -> &mut Self {
133 self.nodes.insert(
134 name.clone(),
135 InfluenceNode {
136 name,
137 node_type: NodeType::Utility,
138 cardinality: 1, parents,
140 },
141 );
142 self
143 }
144
145 pub fn set_cpt(&mut self, node: &str, cpt: ArrayD<f64>) -> Result<&mut Self> {
147 if let Some(n) = self.nodes.get(node) {
148 if n.node_type != NodeType::Chance {
149 return Err(PgmError::InvalidDistribution(format!(
150 "Node {} is not a chance node",
151 node
152 )));
153 }
154 } else {
155 return Err(PgmError::VariableNotFound(node.to_string()));
156 }
157 self.cpts.insert(node.to_string(), cpt);
158 Ok(self)
159 }
160
161 pub fn set_utility(&mut self, node: &str, utility: ArrayD<f64>) -> Result<&mut Self> {
163 if let Some(n) = self.nodes.get(node) {
164 if n.node_type != NodeType::Utility {
165 return Err(PgmError::InvalidDistribution(format!(
166 "Node {} is not a utility node",
167 node
168 )));
169 }
170 } else {
171 return Err(PgmError::VariableNotFound(node.to_string()));
172 }
173 self.utilities.insert(node.to_string(), utility);
174 Ok(self)
175 }
176
177 pub fn set_decision_order(&mut self, order: Vec<String>) -> &mut Self {
179 self.decision_order = order;
180 self
181 }
182
183 pub fn chance_nodes(&self) -> Vec<&InfluenceNode> {
185 self.nodes
186 .values()
187 .filter(|n| n.node_type == NodeType::Chance)
188 .collect()
189 }
190
191 pub fn decision_nodes(&self) -> Vec<&InfluenceNode> {
193 self.nodes
194 .values()
195 .filter(|n| n.node_type == NodeType::Decision)
196 .collect()
197 }
198
199 pub fn utility_nodes(&self) -> Vec<&InfluenceNode> {
201 self.nodes
202 .values()
203 .filter(|n| n.node_type == NodeType::Utility)
204 .collect()
205 }
206
207 pub fn get_node(&self, name: &str) -> Option<&InfluenceNode> {
209 self.nodes.get(name)
210 }
211
212 pub fn to_factor_graph(&self) -> Result<FactorGraph> {
216 let mut graph = FactorGraph::new();
217
218 for (name, node) in &self.nodes {
220 if node.node_type != NodeType::Utility {
221 graph.add_variable_with_card(
222 name.clone(),
223 format!("{:?}", node.node_type),
224 node.cardinality,
225 );
226 }
227 }
228
229 for (name, cpt) in &self.cpts {
231 if let Some(node) = self.nodes.get(name) {
232 let mut vars = node.parents.clone();
233 vars.push(name.clone());
234
235 let factor = Factor::new(format!("P({})", name), vars, cpt.clone())?;
236 graph.add_factor(factor)?;
237 }
238 }
239
240 for (name, node) in &self.nodes {
242 if node.node_type == NodeType::Decision {
243 let uniform =
244 ArrayD::from_elem(IxDyn(&[node.cardinality]), 1.0 / node.cardinality as f64);
245 let factor = Factor::new(format!("U({})", name), vec![name.clone()], uniform)?;
246 graph.add_factor(factor)?;
247 }
248 }
249
250 Ok(graph)
251 }
252
253 pub fn expected_utility(&self, policy: &HashMap<String, usize>) -> Result<f64> {
257 let graph = self.to_factor_graph()?;
259
260 let ve = VariableElimination::default();
262
263 let mut total_utility = 0.0;
265
266 for (utility_name, utility_table) in &self.utilities {
267 if let Some(node) = self.nodes.get(utility_name) {
268 let parent_cardinalities: Vec<usize> = node
270 .parents
271 .iter()
272 .filter_map(|p| self.nodes.get(p).map(|n| n.cardinality))
273 .collect();
274
275 if parent_cardinalities.is_empty() {
276 total_utility += utility_table.iter().next().copied().unwrap_or(0.0);
278 continue;
279 }
280
281 let total_size: usize = parent_cardinalities.iter().product();
283
284 for flat_idx in 0..total_size {
285 let mut indices = vec![0; parent_cardinalities.len()];
287 let mut remaining = flat_idx;
288 for i in (0..parent_cardinalities.len()).rev() {
289 indices[i] = remaining % parent_cardinalities[i];
290 remaining /= parent_cardinalities[i];
291 }
292
293 let utility_val = utility_table[indices.as_slice()];
295
296 let mut prob = 1.0;
298 for (i, parent) in node.parents.iter().enumerate() {
299 if let Some(parent_node) = self.nodes.get(parent) {
300 match parent_node.node_type {
301 NodeType::Decision => {
302 if let Some(&policy_val) = policy.get(parent) {
304 if policy_val != indices[i] {
305 prob = 0.0;
306 break;
307 }
308 }
309 }
310 NodeType::Chance => {
311 if let Ok(marginal) = ve.marginalize(&graph, parent) {
313 if indices[i] < marginal.len() {
314 prob *= marginal[indices[i]];
315 }
316 }
317 }
318 NodeType::Utility => {}
319 }
320 }
321 }
322
323 total_utility += prob * utility_val;
324 }
325 }
326 }
327
328 Ok(total_utility)
329 }
330
331 pub fn optimal_policy(&self) -> Result<(HashMap<String, usize>, f64)> {
335 let decisions: Vec<_> = self.decision_nodes();
336
337 if decisions.is_empty() {
338 return Ok((HashMap::new(), self.expected_utility(&HashMap::new())?));
339 }
340
341 let mut best_policy = HashMap::new();
343 let mut best_utility = f64::NEG_INFINITY;
344
345 let cardinalities: Vec<usize> = decisions.iter().map(|d| d.cardinality).collect();
346 let total_policies: usize = cardinalities.iter().product();
347
348 for policy_idx in 0..total_policies {
349 let mut policy = HashMap::new();
351 let mut remaining = policy_idx;
352
353 for (i, decision) in decisions.iter().enumerate() {
354 let value = remaining % cardinalities[i];
355 remaining /= cardinalities[i];
356 policy.insert(decision.name.clone(), value);
357 }
358
359 let utility = self.expected_utility(&policy)?;
361
362 if utility > best_utility {
363 best_utility = utility;
364 best_policy = policy;
365 }
366 }
367
368 Ok((best_policy, best_utility))
369 }
370
371 pub fn value_of_perfect_information(&self, node: &str) -> Result<f64> {
376 if let Some(n) = self.nodes.get(node) {
378 if n.node_type != NodeType::Chance {
379 return Err(PgmError::InvalidDistribution(format!(
380 "Node {} is not a chance node",
381 node
382 )));
383 }
384 } else {
385 return Err(PgmError::VariableNotFound(node.to_string()));
386 }
387
388 let (_, base_utility) = self.optimal_policy()?;
390
391 let node_card = self
394 .nodes
395 .get(node)
396 .expect("node must exist in influence diagram nodes")
397 .cardinality;
398
399 let graph = self.to_factor_graph()?;
401 let ve = VariableElimination::default();
402 let marginal = ve.marginalize(&graph, node)?;
403
404 let mut expected_with_info = 0.0;
405
406 for value in 0..node_card {
407 let prob = if value < marginal.len() {
410 marginal[value]
411 } else {
412 0.0
413 };
414
415 expected_with_info += prob * base_utility;
417 }
418
419 Ok((expected_with_info - base_utility).max(0.0))
420 }
421
422 pub fn information_parents(&self, decision: &str) -> Vec<String> {
426 if let Some(node) = self.nodes.get(decision) {
427 if node.node_type == NodeType::Decision {
428 return node.parents.clone();
429 }
430 }
431 Vec::new()
432 }
433
434 pub fn is_well_formed(&self) -> bool {
441 for (name, node) in &self.nodes {
443 if node.node_type == NodeType::Utility {
444 for other in self.nodes.values() {
445 if other.parents.contains(name) {
446 return false;
447 }
448 }
449 }
450 }
451
452 let mut visited = HashSet::new();
454 let mut rec_stack = HashSet::new();
455
456 for name in self.nodes.keys() {
457 if !visited.contains(name) && self.has_cycle(name, &mut visited, &mut rec_stack) {
458 return false;
459 }
460 }
461
462 true
463 }
464
465 fn has_cycle(
467 &self,
468 node: &str,
469 visited: &mut HashSet<String>,
470 rec_stack: &mut HashSet<String>,
471 ) -> bool {
472 visited.insert(node.to_string());
473 rec_stack.insert(node.to_string());
474
475 if let Some(n) = self.nodes.get(node) {
476 for parent in &n.parents {
477 if !visited.contains(parent) {
478 if self.has_cycle(parent, visited, rec_stack) {
479 return true;
480 }
481 } else if rec_stack.contains(parent) {
482 return true;
483 }
484 }
485 }
486
487 rec_stack.remove(node);
488 false
489 }
490
491 pub fn num_nodes(&self) -> usize {
493 self.nodes.len()
494 }
495
496 pub fn num_decisions(&self) -> usize {
498 self.decision_nodes().len()
499 }
500
501 pub fn num_utilities(&self) -> usize {
503 self.utility_nodes().len()
504 }
505}
506
507pub struct InfluenceDiagramBuilder {
509 diagram: InfluenceDiagram,
510}
511
512impl Default for InfluenceDiagramBuilder {
513 fn default() -> Self {
514 Self::new()
515 }
516}
517
518impl InfluenceDiagramBuilder {
519 pub fn new() -> Self {
521 Self {
522 diagram: InfluenceDiagram::new(),
523 }
524 }
525
526 pub fn chance_node(mut self, name: String, cardinality: usize, parents: Vec<String>) -> Self {
528 self.diagram.add_chance_node(name, cardinality, parents);
529 self
530 }
531
532 pub fn decision_node(mut self, name: String, cardinality: usize, parents: Vec<String>) -> Self {
534 self.diagram.add_decision_node(name, cardinality, parents);
535 self
536 }
537
538 pub fn utility_node(mut self, name: String, parents: Vec<String>) -> Self {
540 self.diagram.add_utility_node(name, parents);
541 self
542 }
543
544 pub fn cpt(mut self, node: &str, cpt: ArrayD<f64>) -> Result<Self> {
546 self.diagram.set_cpt(node, cpt)?;
547 Ok(self)
548 }
549
550 pub fn utility(mut self, node: &str, utility: ArrayD<f64>) -> Result<Self> {
552 self.diagram.set_utility(node, utility)?;
553 Ok(self)
554 }
555
556 pub fn build(self) -> InfluenceDiagram {
558 self.diagram
559 }
560}
561
/// Weighted additive combination of several named utility attributes.
#[derive(Clone, Debug)]
pub struct MultiAttributeUtility {
    /// `(attribute name, weight)` pairs in insertion order.
    utilities: Vec<(String, f64)>,
}
568
569impl Default for MultiAttributeUtility {
570 fn default() -> Self {
571 Self::new()
572 }
573}
574
575impl MultiAttributeUtility {
576 pub fn new() -> Self {
578 Self {
579 utilities: Vec::new(),
580 }
581 }
582
583 pub fn add_utility(&mut self, name: String, weight: f64) -> &mut Self {
585 self.utilities.push((name, weight));
586 self
587 }
588
589 pub fn combine(&self, values: &HashMap<String, f64>) -> f64 {
591 let mut total = 0.0;
592
593 for (name, weight) in &self.utilities {
594 if let Some(&value) = values.get(name) {
595 total += weight * value;
596 }
597 }
598
599 total
600 }
601
602 pub fn weights(&self) -> HashMap<String, f64> {
604 self.utilities.iter().cloned().collect()
605 }
606
607 pub fn normalize_weights(&mut self) {
609 let total: f64 = self.utilities.iter().map(|(_, w)| w).sum();
610 if total > 0.0 {
611 for (_, w) in &mut self.utilities {
612 *w /= total;
613 }
614 }
615 }
616}
617
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 1-D `ArrayD` holding `values`.
    fn arr1(values: &[f64]) -> ArrayD<f64> {
        ArrayD::from_shape_vec(IxDyn(&[values.len()]), values.to_vec()).expect("unwrap")
    }

    /// Turns string literals into owned node names.
    fn names(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_influence_diagram_creation() {
        let mut diagram = InfluenceDiagram::new();
        diagram
            .add_chance_node("weather".to_string(), 2, vec![])
            .add_decision_node("umbrella".to_string(), 2, names(&["weather"]))
            .add_utility_node("comfort".to_string(), names(&["weather", "umbrella"]));

        assert_eq!(diagram.num_nodes(), 3);
        assert_eq!(diagram.num_decisions(), 1);
        assert_eq!(diagram.num_utilities(), 1);
    }

    #[test]
    fn test_node_types() {
        let mut diagram = InfluenceDiagram::new();
        diagram
            .add_chance_node("c".to_string(), 2, vec![])
            .add_decision_node("d".to_string(), 2, vec![])
            .add_utility_node("u".to_string(), names(&["c", "d"]));

        assert_eq!(diagram.chance_nodes().len(), 1);
        assert_eq!(diagram.decision_nodes().len(), 1);
        assert_eq!(diagram.utility_nodes().len(), 1);
    }

    #[test]
    fn test_set_cpt() {
        let mut diagram = InfluenceDiagram::new();
        diagram.add_chance_node("x".to_string(), 2, vec![]);

        assert!(diagram.set_cpt("x", arr1(&[0.3, 0.7])).is_ok());
    }

    #[test]
    fn test_set_cpt_invalid_node() {
        let mut diagram = InfluenceDiagram::new();
        diagram.add_decision_node("d".to_string(), 2, vec![]);

        assert!(diagram.set_cpt("d", arr1(&[0.3, 0.7])).is_err());
    }

    #[test]
    fn test_set_utility() {
        let mut diagram = InfluenceDiagram::new();
        diagram
            .add_chance_node("x".to_string(), 2, vec![])
            .add_utility_node("u".to_string(), names(&["x"]));

        assert!(diagram.set_utility("u", arr1(&[10.0, 20.0])).is_ok());
    }

    #[test]
    fn test_to_factor_graph() {
        let mut diagram = InfluenceDiagram::new();
        diagram
            .add_chance_node("x".to_string(), 2, vec![])
            .add_decision_node("d".to_string(), 2, vec![]);
        diagram.set_cpt("x", arr1(&[0.5, 0.5])).expect("unwrap");

        let graph = diagram.to_factor_graph().expect("unwrap");
        assert_eq!(graph.num_variables(), 2);
    }

    #[test]
    fn test_well_formed() {
        let mut diagram = InfluenceDiagram::new();
        diagram
            .add_chance_node("x".to_string(), 2, vec![])
            .add_decision_node("d".to_string(), 2, names(&["x"]))
            .add_utility_node("u".to_string(), names(&["d"]));

        assert!(diagram.is_well_formed());
    }

    #[test]
    fn test_information_parents() {
        let mut diagram = InfluenceDiagram::new();
        diagram
            .add_chance_node("x".to_string(), 2, vec![])
            .add_decision_node("d".to_string(), 2, names(&["x"]));

        assert_eq!(diagram.information_parents("d"), names(&["x"]));
    }

    #[test]
    fn test_builder() {
        let diagram = InfluenceDiagramBuilder::new()
            .chance_node("x".to_string(), 2, vec![])
            .decision_node("d".to_string(), 2, names(&["x"]))
            .utility_node("u".to_string(), names(&["x", "d"]))
            .build();

        assert_eq!(diagram.num_nodes(), 3);
    }

    #[test]
    fn test_multi_attribute_utility() {
        let mut maut = MultiAttributeUtility::new();
        maut.add_utility("cost".to_string(), 0.4)
            .add_utility("quality".to_string(), 0.6);

        let values: HashMap<String, f64> = [("cost", 10.0), ("quality", 20.0)]
            .iter()
            .map(|&(k, v)| (k.to_string(), v))
            .collect();

        // 0.4 * 10 + 0.6 * 20 = 16
        assert!((maut.combine(&values) - 16.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_weights() {
        let mut maut = MultiAttributeUtility::new();
        maut.add_utility("a".to_string(), 2.0)
            .add_utility("b".to_string(), 3.0);

        maut.normalize_weights();

        let total: f64 = maut.weights().values().sum();
        assert!((total - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_expected_utility_simple() {
        let mut diagram = InfluenceDiagram::new();
        diagram
            .add_decision_node("d".to_string(), 2, vec![])
            .add_utility_node("u".to_string(), names(&["d"]));
        diagram.set_utility("u", arr1(&[10.0, 20.0])).expect("unwrap");

        let policy: HashMap<String, usize> = std::iter::once(("d".to_string(), 1)).collect();

        let eu = diagram.expected_utility(&policy).expect("unwrap");
        assert!((eu - 20.0).abs() < 1e-6);
    }

    #[test]
    fn test_optimal_policy_simple() {
        let mut diagram = InfluenceDiagram::new();
        diagram
            .add_decision_node("d".to_string(), 2, vec![])
            .add_utility_node("u".to_string(), names(&["d"]));
        diagram.set_utility("u", arr1(&[10.0, 20.0])).expect("unwrap");

        let (policy, eu) = diagram.optimal_policy().expect("unwrap");
        assert_eq!(policy.get("d"), Some(&1));
        assert!((eu - 20.0).abs() < 1e-6);
    }
}