1use super::TLExpr;
17use std::collections::HashMap;
18
/// A closed interval `[lower, upper]` of probabilities, representing an
/// imprecise (interval-valued) probability assessment.
///
/// Intervals built through [`ProbabilityInterval::new`] satisfy
/// `0.0 <= lower <= upper <= 1.0`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ProbabilityInterval {
    /// Lower probability bound.
    pub lower: f64,
    /// Upper probability bound.
    pub upper: f64,
}

impl ProbabilityInterval {
    /// Creates an interval, returning `None` unless `0 <= lower <= upper <= 1`.
    ///
    /// NaN bounds are rejected: every comparison involving NaN is false, so the
    /// check below fails for them.
    pub fn new(lower: f64, upper: f64) -> Option<Self> {
        // `contains` is `0.0 <= lower && lower <= upper`, false for any NaN.
        let ordered = (0.0..=upper).contains(&lower);
        if ordered && upper <= 1.0 {
            Some(Self { lower, upper })
        } else {
            None
        }
    }

    /// Creates a degenerate interval `[prob, prob]`; `None` if `prob` is not in `[0, 1]`.
    pub fn precise(prob: f64) -> Option<Self> {
        Self::new(prob, prob)
    }

    /// The uninformative interval `[0, 1]`.
    pub fn vacuous() -> Self {
        Self {
            lower: 0.0,
            upper: 1.0,
        }
    }

    /// Width `upper - lower` of the interval.
    pub fn width(&self) -> f64 {
        self.upper - self.lower
    }

    /// True when the bounds coincide up to an absolute tolerance of `1e-10`.
    pub fn is_precise(&self) -> bool {
        (self.upper - self.lower).abs() < 1e-10
    }

    /// True when the interval is exactly `[0, 1]`.
    pub fn is_vacuous(&self) -> bool {
        self.lower == 0.0 && self.upper == 1.0
    }

    /// Bounds on `P(not A)` given bounds on `P(A)`: `[1 - upper, 1 - lower]`.
    pub fn complement(&self) -> Self {
        Self {
            lower: 1.0 - self.upper,
            upper: 1.0 - self.lower,
        }
    }

    /// Fréchet bounds on `P(A and B)` with no assumption about dependence:
    /// `[max(0, a.lower + b.lower - 1), min(a.upper, b.upper)]`.
    pub fn and(&self, other: &Self) -> Self {
        let lower = (self.lower + other.lower - 1.0).max(0.0);
        let upper = self.upper.min(other.upper);
        Self { lower, upper }
    }

    /// Fréchet bounds on `P(A or B)` with no assumption about dependence:
    /// `[max(a.lower, b.lower), min(1, a.upper + b.upper)]`.
    pub fn or(&self, other: &Self) -> Self {
        let lower = self.lower.max(other.lower);
        let upper = (self.upper + other.upper).min(1.0);
        Self { lower, upper }
    }

    /// Bounds on the material implication `A -> B`, computed as `not A or B`.
    pub fn implies(&self, other: &Self) -> Self {
        self.complement().or(other)
    }

    /// Bounds on `P(B | A)`, where `self` bounds `P(A)` and `joint` bounds
    /// `P(A and B)`.
    ///
    /// Returns `None` when `P(A)` is necessarily zero, where conditioning is
    /// undefined. The upper bound is capped at `1`. Without the cap, dividing
    /// by a small lower bound of `P(A)` could exceed 1, and a conditional
    /// probability never does.
    pub fn conditional(&self, joint: &Self) -> Option<Self> {
        if self.upper <= 0.0 {
            return None;
        }
        if self.lower == 0.0 {
            // P(A) can be arbitrarily small, so P(B | A) can reach 1 whenever
            // the joint probability can be positive.
            let upper = if joint.upper > 0.0 { 1.0 } else { 0.0 };
            return Some(Self { lower: 0.0, upper });
        }
        let upper = (joint.upper / self.lower).min(1.0);
        // Keep lower <= upper even for inconsistent inputs (joint above P(A)).
        let lower = (joint.lower / self.upper).min(upper);
        Some(Self { lower, upper })
    }

    /// Intersection of two intervals, or `None` if they are disjoint.
    pub fn intersect(&self, other: &Self) -> Option<Self> {
        let lower = self.lower.max(other.lower);
        let upper = self.upper.min(other.upper);
        if lower <= upper {
            Some(Self { lower, upper })
        } else {
            None
        }
    }

    /// Convex combination `weight * self + (1 - weight) * other`, bound-wise.
    ///
    /// Returns `None` when `weight` is outside `[0, 1]`; this includes NaN.
    pub fn convex_combine(&self, other: &Self, weight: f64) -> Option<Self> {
        if !(0.0..=1.0).contains(&weight) {
            return None;
        }
        Some(Self {
            lower: self.lower * weight + other.lower * (1.0 - weight),
            upper: self.upper * weight + other.upper * (1.0 - weight),
        })
    }
}
152
153#[derive(Debug, Clone)]
157pub struct CredalSet {
158 extreme_points: Vec<HashMap<String, f64>>,
160}
161
162impl CredalSet {
163 pub fn new(extreme_points: Vec<HashMap<String, f64>>) -> Self {
165 Self { extreme_points }
166 }
167
168 pub fn precise(distribution: HashMap<String, f64>) -> Self {
170 Self {
171 extreme_points: vec![distribution],
172 }
173 }
174
175 pub fn lower_prob(&self, event: &str) -> f64 {
177 self.extreme_points
178 .iter()
179 .filter_map(|dist| dist.get(event).copied())
180 .fold(f64::INFINITY, f64::min)
181 }
182
183 pub fn upper_prob(&self, event: &str) -> f64 {
185 self.extreme_points
186 .iter()
187 .filter_map(|dist| dist.get(event).copied())
188 .fold(f64::NEG_INFINITY, f64::max)
189 }
190
191 pub fn prob_interval(&self, event: &str) -> ProbabilityInterval {
193 ProbabilityInterval {
194 lower: self.lower_prob(event),
195 upper: self.upper_prob(event),
196 }
197 }
198
199 pub fn size(&self) -> usize {
201 self.extreme_points.len()
202 }
203
204 pub fn is_precise(&self) -> bool {
206 self.extreme_points.len() == 1
207 }
208}
209
210pub fn propagate_probabilities(
215 expr: &TLExpr,
216 prob_map: &HashMap<String, ProbabilityInterval>,
217) -> ProbabilityInterval {
218 match expr {
219 TLExpr::Pred { name, .. } => prob_map
220 .get(name)
221 .copied()
222 .unwrap_or_else(ProbabilityInterval::vacuous),
223
224 TLExpr::Constant(v) => {
225 if *v >= 1.0 {
226 ProbabilityInterval::precise(1.0).unwrap()
227 } else if *v <= 0.0 {
228 ProbabilityInterval::precise(0.0).unwrap()
229 } else {
230 ProbabilityInterval::vacuous()
231 }
232 }
233
234 TLExpr::And(left, right) => {
235 let left_prob = propagate_probabilities(left, prob_map);
236 let right_prob = propagate_probabilities(right, prob_map);
237 left_prob.and(&right_prob)
238 }
239
240 TLExpr::Or(left, right) => {
241 let left_prob = propagate_probabilities(left, prob_map);
242 let right_prob = propagate_probabilities(right, prob_map);
243 left_prob.or(&right_prob)
244 }
245
246 TLExpr::Not(inner) => {
247 let inner_prob = propagate_probabilities(inner, prob_map);
248 inner_prob.complement()
249 }
250
251 TLExpr::Imply(premise, conclusion) => {
252 let premise_prob = propagate_probabilities(premise, prob_map);
253 let conclusion_prob = propagate_probabilities(conclusion, prob_map);
254 premise_prob.implies(&conclusion_prob)
255 }
256
257 TLExpr::WeightedRule { weight, rule } => {
259 let rule_prob = propagate_probabilities(rule, prob_map);
260 ProbabilityInterval {
262 lower: rule_prob.lower * weight,
263 upper: rule_prob.upper * weight,
264 }
265 }
266
267 TLExpr::ProbabilisticChoice { alternatives } => {
269 let mut lower_sum = 0.0;
270 let mut upper_sum = 0.0;
271 let mut total_weight = 0.0;
272
273 for (prob, expr) in alternatives {
274 let expr_interval = propagate_probabilities(expr, prob_map);
275 lower_sum += prob * expr_interval.lower;
276 upper_sum += prob * expr_interval.upper;
277 total_weight += prob;
278 }
279
280 if total_weight > 0.0 && (total_weight - 1.0).abs() > 1e-10 {
282 lower_sum /= total_weight;
283 upper_sum /= total_weight;
284 }
285
286 ProbabilityInterval {
287 lower: lower_sum.clamp(0.0, 1.0),
288 upper: upper_sum.clamp(0.0, 1.0),
289 }
290 }
291
292 _ => ProbabilityInterval::vacuous(),
294 }
295}
296
297pub fn compute_tight_bounds(
302 expr: &TLExpr,
303 prob_map: &HashMap<String, ProbabilityInterval>,
304) -> ProbabilityInterval {
305 let mut current = propagate_probabilities(expr, prob_map);
307
308 for _ in 0..3 {
311 current = tighten_iteration(expr, prob_map, ¤t);
312 }
313
314 current
315}
316
317fn tighten_iteration(
318 expr: &TLExpr,
319 prob_map: &HashMap<String, ProbabilityInterval>,
320 current: &ProbabilityInterval,
321) -> ProbabilityInterval {
322 match expr {
323 TLExpr::And(left, right) => {
324 let left_prob = compute_tight_bounds(left, prob_map);
325 let right_prob = compute_tight_bounds(right, prob_map);
326
327 let mut result = left_prob.and(&right_prob);
329
330 if let Some(intersection) = result.intersect(current) {
332 result = intersection;
333 }
334
335 result
336 }
337
338 TLExpr::Or(left, right) => {
339 let left_prob = compute_tight_bounds(left, prob_map);
340 let right_prob = compute_tight_bounds(right, prob_map);
341
342 let mut result = left_prob.or(&right_prob);
343
344 if let Some(intersection) = result.intersect(current) {
345 result = intersection;
346 }
347
348 result
349 }
350
351 _ => propagate_probabilities(expr, prob_map),
352 }
353}
354
355pub fn extract_probabilistic_semantics(expr: &TLExpr) -> Vec<(f64, TLExpr)> {
359 let mut weighted_rules = Vec::new();
360 extract_weighted_rec(expr, &mut weighted_rules);
361 weighted_rules
362}
363
364fn extract_weighted_rec(expr: &TLExpr, result: &mut Vec<(f64, TLExpr)>) {
365 match expr {
366 TLExpr::WeightedRule { weight, rule } => {
367 result.push((*weight, (**rule).clone()));
368 extract_weighted_rec(rule, result);
369 }
370
371 TLExpr::ProbabilisticChoice { alternatives } => {
372 for (prob, expr) in alternatives {
373 result.push((*prob, expr.clone()));
374 extract_weighted_rec(expr, result);
375 }
376 }
377
378 TLExpr::And(l, r) | TLExpr::Or(l, r) | TLExpr::Imply(l, r) => {
379 extract_weighted_rec(l, result);
380 extract_weighted_rec(r, result);
381 }
382
383 TLExpr::Not(e) => extract_weighted_rec(e, result),
384
385 _ => {}
386 }
387}
388
389pub fn mln_probability(
394 _expr: &TLExpr,
395 weights: &[(f64, TLExpr)],
396 evidence: &HashMap<String, bool>,
397) -> f64 {
398 let mut total_weight = 0.0;
400
401 for (weight, rule) in weights {
402 if evaluates_true(rule, evidence) {
403 total_weight += weight;
404 }
405 }
406
407 1.0 / (1.0 + (-total_weight).exp())
409}
410
411fn evaluates_true(expr: &TLExpr, evidence: &HashMap<String, bool>) -> bool {
413 match expr {
414 TLExpr::Pred { name, .. } => evidence.get(name).copied().unwrap_or(false),
415
416 TLExpr::And(l, r) => evaluates_true(l, evidence) && evaluates_true(r, evidence),
417
418 TLExpr::Or(l, r) => evaluates_true(l, evidence) || evaluates_true(r, evidence),
419
420 TLExpr::Not(e) => !evaluates_true(e, evidence),
421
422 TLExpr::Imply(l, r) => !evaluates_true(l, evidence) || evaluates_true(r, evidence),
423
424 TLExpr::Constant(v) => *v >= 1.0,
425
426 _ => false,
427 }
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two floats agree to within an absolute tolerance of `1e-10`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_probability_interval_creation() {
        let interval = ProbabilityInterval::new(0.3, 0.7).unwrap();
        assert_close(interval.lower, 0.3);
        assert_close(interval.upper, 0.7);
        assert_close(interval.width(), 0.4);

        assert!(ProbabilityInterval::new(-0.1, 0.5).is_none());
        assert!(ProbabilityInterval::new(0.8, 0.5).is_none());
        assert!(ProbabilityInterval::new(0.5, 1.5).is_none());
    }

    #[test]
    fn test_precise_probability() {
        let precise = ProbabilityInterval::precise(0.5).unwrap();
        assert!(precise.is_precise());
        assert_eq!(precise.width(), 0.0);
    }

    #[test]
    fn test_vacuous_interval() {
        let vacuous = ProbabilityInterval::vacuous();
        assert!(vacuous.is_vacuous());
        assert_eq!(vacuous.width(), 1.0);
    }

    #[test]
    fn test_complement() {
        let interval = ProbabilityInterval::new(0.3, 0.7).unwrap();
        let complement = interval.complement();
        assert_close(complement.lower, 0.3);
        assert_close(complement.upper, 0.7);
    }

    #[test]
    fn test_frechet_and() {
        let p_a = ProbabilityInterval::new(0.4, 0.6).unwrap();
        let p_b = ProbabilityInterval::new(0.5, 0.8).unwrap();
        let p_and = p_a.and(&p_b);

        // 0.4 + 0.5 - 1 < 0, so the lower bound is clamped to 0.
        assert_eq!(p_and.lower, 0.0);
        assert_eq!(p_and.upper, 0.6);
    }

    #[test]
    fn test_frechet_or() {
        let p_a = ProbabilityInterval::new(0.4, 0.6).unwrap();
        let p_b = ProbabilityInterval::new(0.5, 0.8).unwrap();
        let p_or = p_a.or(&p_b);

        assert_eq!(p_or.lower, 0.5);
        // 0.6 + 0.8 > 1, so the upper bound is clamped to 1.
        assert_eq!(p_or.upper, 1.0);
    }

    #[test]
    fn test_implication_bounds() {
        let p_a = ProbabilityInterval::new(0.3, 0.5).unwrap();
        let p_b = ProbabilityInterval::new(0.6, 0.9).unwrap();
        let p_implies = p_a.implies(&p_b);

        let expected = p_a.complement().or(&p_b);
        assert_eq!(p_implies.lower, expected.lower);
        assert_eq!(p_implies.upper, expected.upper);
    }

    #[test]
    fn test_conditional_probability() {
        let p_a = ProbabilityInterval::new(0.4, 0.6).unwrap();
        let p_a_and_b = ProbabilityInterval::new(0.2, 0.3).unwrap();

        let p_b_given_a = p_a.conditional(&p_a_and_b).unwrap();

        // lower = 0.2 / 0.6, upper = 0.3 / 0.4
        assert_close(p_b_given_a.lower, 0.2 / 0.6);
        assert_close(p_b_given_a.upper, 0.75);
    }

    #[test]
    fn test_interval_intersection() {
        let i1 = ProbabilityInterval::new(0.2, 0.7).unwrap();
        let i2 = ProbabilityInterval::new(0.5, 0.9).unwrap();

        let intersection = i1.intersect(&i2).unwrap();
        assert_eq!(intersection.lower, 0.5);
        assert_eq!(intersection.upper, 0.7);

        // Disjoint intervals have no intersection.
        let i3 = ProbabilityInterval::new(0.1, 0.3).unwrap();
        let i4 = ProbabilityInterval::new(0.6, 0.9).unwrap();
        assert!(i3.intersect(&i4).is_none());
    }

    #[test]
    fn test_convex_combination() {
        let i1 = ProbabilityInterval::new(0.2, 0.4).unwrap();
        let i2 = ProbabilityInterval::new(0.6, 0.8).unwrap();

        let combo = i1.convex_combine(&i2, 0.5).unwrap();
        assert_close(combo.lower, 0.4);
        assert_close(combo.upper, 0.6);
    }

    #[test]
    fn test_convex_combination_rejects_out_of_range_weight() {
        let i1 = ProbabilityInterval::new(0.2, 0.4).unwrap();
        let i2 = ProbabilityInterval::new(0.6, 0.8).unwrap();

        assert!(i1.convex_combine(&i2, -0.1).is_none());
        assert!(i1.convex_combine(&i2, 1.1).is_none());
    }

    #[test]
    fn test_propagate_probabilities_and() {
        let mut prob_map = HashMap::new();
        prob_map.insert("P".to_string(), ProbabilityInterval::new(0.4, 0.6).unwrap());
        prob_map.insert("Q".to_string(), ProbabilityInterval::new(0.5, 0.8).unwrap());

        let expr = TLExpr::and(TLExpr::pred("P", vec![]), TLExpr::pred("Q", vec![]));

        let result = propagate_probabilities(&expr, &prob_map);
        assert_eq!(result.lower, 0.0);
        assert_eq!(result.upper, 0.6);
    }

    #[test]
    fn test_propagate_probabilities_or() {
        let mut prob_map = HashMap::new();
        prob_map.insert("P".to_string(), ProbabilityInterval::new(0.4, 0.6).unwrap());
        prob_map.insert("Q".to_string(), ProbabilityInterval::new(0.5, 0.8).unwrap());

        let expr = TLExpr::or(TLExpr::pred("P", vec![]), TLExpr::pred("Q", vec![]));

        let result = propagate_probabilities(&expr, &prob_map);
        assert_eq!(result.lower, 0.5);
        assert_eq!(result.upper, 1.0);
    }

    #[test]
    fn test_propagate_probabilities_not() {
        let mut prob_map = HashMap::new();
        prob_map.insert("P".to_string(), ProbabilityInterval::new(0.3, 0.7).unwrap());

        let expr = TLExpr::negate(TLExpr::pred("P", vec![]));

        let result = propagate_probabilities(&expr, &prob_map);
        assert_close(result.lower, 0.3);
        assert_close(result.upper, 0.7);
    }

    #[test]
    fn test_weighted_rule_propagation() {
        let mut prob_map = HashMap::new();
        prob_map.insert("P".to_string(), ProbabilityInterval::new(0.5, 0.8).unwrap());

        let expr = TLExpr::weighted_rule(0.5, TLExpr::pred("P", vec![]));

        let result = propagate_probabilities(&expr, &prob_map);
        // [0.5, 0.8] scaled by 0.5
        assert_close(result.lower, 0.25);
        assert_close(result.upper, 0.4);
    }

    #[test]
    fn test_probabilistic_choice() {
        let mut prob_map = HashMap::new();
        prob_map.insert("P".to_string(), ProbabilityInterval::precise(0.6).unwrap());
        prob_map.insert("Q".to_string(), ProbabilityInterval::precise(0.4).unwrap());

        let expr = TLExpr::probabilistic_choice(vec![
            (0.5, TLExpr::pred("P", vec![])),
            (0.5, TLExpr::pred("Q", vec![])),
        ]);

        let result = propagate_probabilities(&expr, &prob_map);
        // 0.5 * 0.6 + 0.5 * 0.4
        assert_close(result.lower, 0.5);
        assert_close(result.upper, 0.5);
    }

    #[test]
    fn test_credal_set() {
        let mut dist1 = HashMap::new();
        dist1.insert("A".to_string(), 0.3);
        dist1.insert("B".to_string(), 0.7);

        let mut dist2 = HashMap::new();
        dist2.insert("A".to_string(), 0.6);
        dist2.insert("B".to_string(), 0.4);

        let credal = CredalSet::new(vec![dist1, dist2]);

        assert_eq!(credal.lower_prob("A"), 0.3);
        assert_eq!(credal.upper_prob("A"), 0.6);
        assert!(!credal.is_precise());
    }

    #[test]
    fn test_credal_set_precise() {
        let mut dist = HashMap::new();
        dist.insert("A".to_string(), 0.25);

        let credal = CredalSet::precise(dist);

        assert!(credal.is_precise());
        assert_eq!(credal.size(), 1);
        assert_eq!(credal.lower_prob("A"), 0.25);
        assert_eq!(credal.upper_prob("A"), 0.25);
    }

    #[test]
    fn test_mln_probability() {
        let rule = TLExpr::pred("P", vec![]);
        let weights = vec![(2.0, rule.clone())];

        let mut evidence = HashMap::new();
        evidence.insert("P".to_string(), true);

        let prob = mln_probability(&rule, &weights, &evidence);
        // logistic(2.0), approximately 0.8808
        assert_close(prob, 1.0 / (1.0 + (-2.0_f64).exp()));
    }
}