// Source file: tensorlogic_ir/expr/probabilistic_reasoning.rs
//! Probabilistic reasoning and probability bounds propagation.
//!
//! This module implements probabilistic inference and uncertainty quantification including:
//! - Probability interval arithmetic (Fréchet bounds)
//! - Imprecise probabilities (lower/upper bounds)
//! - Credal sets and convex sets of probability distributions
//! - Probabilistic semantics for weighted rules
//! - Probability propagation through logical connectives
//!
//! # Applications
//! - Markov Logic Networks (MLNs)
//! - Probabilistic Logic Programs
//! - Bayesian inference with interval probabilities
//! - Uncertainty quantification under incomplete information

use super::TLExpr;
use std::collections::HashMap;

/// Probability interval representing imprecise probabilities.
///
/// Represents the set [lower, upper] of possible probability values.
/// Follows the theory of imprecise probabilities and credal sets.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ProbabilityInterval {
    /// Lower probability bound (must be in [0, 1])
    pub lower: f64,
    /// Upper probability bound (must be in [0, 1] and >= lower)
    pub upper: f64,
}

impl ProbabilityInterval {
    /// Create a new probability interval.
    ///
    /// Returns `None` if the bounds are invalid: outside \[0, 1\],
    /// `lower > upper`, or NaN. The range checks use `contains`, which a NaN
    /// bound always fails; the previous `lower < 0.0 || upper > 1.0` form let
    /// NaN slip through because NaN compares false to everything.
    pub fn new(lower: f64, upper: f64) -> Option<Self> {
        if (0.0..=1.0).contains(&lower) && (0.0..=1.0).contains(&upper) && lower <= upper {
            Some(Self { lower, upper })
        } else {
            None
        }
    }

    /// Create a precise probability (point interval).
    ///
    /// Returns None if `prob` is not a valid probability.
    pub fn precise(prob: f64) -> Option<Self> {
        Self::new(prob, prob)
    }

    /// Create a vacuous interval [0, 1] (complete ignorance).
    pub fn vacuous() -> Self {
        Self {
            lower: 0.0,
            upper: 1.0,
        }
    }

    /// Width of the interval (measure of imprecision).
    pub fn width(&self) -> f64 {
        self.upper - self.lower
    }

    /// Check if this is a precise probability (width below a small tolerance).
    pub fn is_precise(&self) -> bool {
        (self.upper - self.lower).abs() < 1e-10
    }

    /// Check if the interval is vacuous (completely imprecise).
    pub fn is_vacuous(&self) -> bool {
        self.lower == 0.0 && self.upper == 1.0
    }

    /// Complement: P(¬A) given P(A).
    pub fn complement(&self) -> Self {
        Self {
            lower: 1.0 - self.upper,
            upper: 1.0 - self.lower,
        }
    }

    /// Conjunction bounds: P(A ∧ B) given P(A) and P(B).
    ///
    /// Uses Fréchet bounds: max(0, P(A) + P(B) - 1) ≤ P(A ∧ B) ≤ min(P(A), P(B))
    pub fn and(&self, other: &Self) -> Self {
        let lower = (self.lower + other.lower - 1.0).max(0.0);
        let upper = self.upper.min(other.upper);
        Self { lower, upper }
    }

    /// Disjunction bounds: P(A ∨ B) given P(A) and P(B).
    ///
    /// Uses Fréchet bounds: max(P(A), P(B)) ≤ P(A ∨ B) ≤ min(1, P(A) + P(B))
    pub fn or(&self, other: &Self) -> Self {
        let lower = self.lower.max(other.lower);
        let upper = (self.upper + other.upper).min(1.0);
        Self { lower, upper }
    }

    /// Implication bounds: P(A → B) given P(A) and P(B).
    ///
    /// A → B ≡ ¬A ∨ B, so use complement and disjunction.
    pub fn implies(&self, other: &Self) -> Self {
        self.complement().or(other)
    }

    /// Conditional probability bounds: P(B|A) given P(A) (`self`) and
    /// P(A ∧ B) (`joint`).
    ///
    /// If P(A) > 0, returns P(A ∧ B) / P(A) via interval division:
    /// \[a,b\] / \[c,d\] = \[a/d, b/c\] for positive intervals. Quotients are
    /// clamped into \[0, 1\]: without the clamp, inconsistent inputs (or the
    /// tiny denominator in the conservative branch below) produced
    /// "probabilities" greater than 1.
    pub fn conditional(&self, joint: &Self) -> Option<Self> {
        if self.upper == 0.0 {
            // Cannot condition on a zero-probability event.
            None
        } else if self.lower == 0.0 {
            // The conditioning event may have probability arbitrarily close
            // to zero, so the upper quotient is unbounded; clamp to 1.
            Some(Self {
                lower: 0.0,
                upper: (joint.upper / self.lower.max(1e-10)).min(1.0),
            })
        } else {
            let upper = (joint.upper / self.lower).min(1.0);
            // Keep lower ≤ upper even for inconsistent joint/marginal inputs.
            let lower = (joint.lower / self.upper).min(upper);
            Some(Self { lower, upper })
        }
    }

    /// Intersection of two probability intervals.
    ///
    /// Returns None if the intervals don't overlap.
    pub fn intersect(&self, other: &Self) -> Option<Self> {
        let lower = self.lower.max(other.lower);
        let upper = self.upper.min(other.upper);
        if lower <= upper {
            Some(Self { lower, upper })
        } else {
            None
        }
    }

    /// Convex combination: `weight * self + (1 - weight) * other`.
    ///
    /// Useful for averaging or mixing probability assessments.
    /// Returns None for weights outside [0, 1] (including NaN).
    pub fn convex_combine(&self, other: &Self, weight: f64) -> Option<Self> {
        if !(0.0..=1.0).contains(&weight) {
            return None;
        }
        Some(Self {
            lower: self.lower * weight + other.lower * (1.0 - weight),
            upper: self.upper * weight + other.upper * (1.0 - weight),
        })
    }
}

153/// Credal set: convex set of probability distributions.
154///
155/// Represented by extreme points (vertices) of the credal set.
156#[derive(Debug, Clone)]
157pub struct CredalSet {
158    /// Extreme probability distributions (each sums to 1)
159    extreme_points: Vec<HashMap<String, f64>>,
160}
161
162impl CredalSet {
163    /// Create a credal set from extreme points.
164    pub fn new(extreme_points: Vec<HashMap<String, f64>>) -> Self {
165        Self { extreme_points }
166    }
167
168    /// Create a precise credal set (single distribution).
169    pub fn precise(distribution: HashMap<String, f64>) -> Self {
170        Self {
171            extreme_points: vec![distribution],
172        }
173    }
174
175    /// Get lower probability bound for an event.
176    pub fn lower_prob(&self, event: &str) -> f64 {
177        self.extreme_points
178            .iter()
179            .filter_map(|dist| dist.get(event).copied())
180            .fold(f64::INFINITY, f64::min)
181    }
182
183    /// Get upper probability bound for an event.
184    pub fn upper_prob(&self, event: &str) -> f64 {
185        self.extreme_points
186            .iter()
187            .filter_map(|dist| dist.get(event).copied())
188            .fold(f64::NEG_INFINITY, f64::max)
189    }
190
191    /// Get probability interval for an event.
192    pub fn prob_interval(&self, event: &str) -> ProbabilityInterval {
193        ProbabilityInterval {
194            lower: self.lower_prob(event),
195            upper: self.upper_prob(event),
196        }
197    }
198
199    /// Number of extreme points in the credal set.
200    pub fn size(&self) -> usize {
201        self.extreme_points.len()
202    }
203
204    /// Check if credal set is precise (single distribution).
205    pub fn is_precise(&self) -> bool {
206        self.extreme_points.len() == 1
207    }
208}
209
210/// Propagate probability intervals through a logical expression.
211///
212/// Given probability assignments to atomic predicates, computes
213/// probability bounds for the compound expression.
214pub fn propagate_probabilities(
215    expr: &TLExpr,
216    prob_map: &HashMap<String, ProbabilityInterval>,
217) -> ProbabilityInterval {
218    match expr {
219        TLExpr::Pred { name, .. } => prob_map
220            .get(name)
221            .copied()
222            .unwrap_or_else(ProbabilityInterval::vacuous),
223
224        TLExpr::Constant(v) => {
225            if *v >= 1.0 {
226                ProbabilityInterval::precise(1.0).unwrap()
227            } else if *v <= 0.0 {
228                ProbabilityInterval::precise(0.0).unwrap()
229            } else {
230                ProbabilityInterval::vacuous()
231            }
232        }
233
234        TLExpr::And(left, right) => {
235            let left_prob = propagate_probabilities(left, prob_map);
236            let right_prob = propagate_probabilities(right, prob_map);
237            left_prob.and(&right_prob)
238        }
239
240        TLExpr::Or(left, right) => {
241            let left_prob = propagate_probabilities(left, prob_map);
242            let right_prob = propagate_probabilities(right, prob_map);
243            left_prob.or(&right_prob)
244        }
245
246        TLExpr::Not(inner) => {
247            let inner_prob = propagate_probabilities(inner, prob_map);
248            inner_prob.complement()
249        }
250
251        TLExpr::Imply(premise, conclusion) => {
252            let premise_prob = propagate_probabilities(premise, prob_map);
253            let conclusion_prob = propagate_probabilities(conclusion, prob_map);
254            premise_prob.implies(&conclusion_prob)
255        }
256
257        // For weighted rules, the weight represents confidence
258        TLExpr::WeightedRule { weight, rule } => {
259            let rule_prob = propagate_probabilities(rule, prob_map);
260            // Weight modulates the probability bounds
261            ProbabilityInterval {
262                lower: rule_prob.lower * weight,
263                upper: rule_prob.upper * weight,
264            }
265        }
266
267        // For probabilistic choice, compute expected bounds
268        TLExpr::ProbabilisticChoice { alternatives } => {
269            let mut lower_sum = 0.0;
270            let mut upper_sum = 0.0;
271            let mut total_weight = 0.0;
272
273            for (prob, expr) in alternatives {
274                let expr_interval = propagate_probabilities(expr, prob_map);
275                lower_sum += prob * expr_interval.lower;
276                upper_sum += prob * expr_interval.upper;
277                total_weight += prob;
278            }
279
280            // Normalize if weights don't sum to 1
281            if total_weight > 0.0 && (total_weight - 1.0).abs() > 1e-10 {
282                lower_sum /= total_weight;
283                upper_sum /= total_weight;
284            }
285
286            ProbabilityInterval {
287                lower: lower_sum.clamp(0.0, 1.0),
288                upper: upper_sum.clamp(0.0, 1.0),
289            }
290        }
291
292        // Default: vacuous interval (no information)
293        _ => ProbabilityInterval::vacuous(),
294    }
295}
296
297/// Compute tightest probability bounds for an expression using optimization.
298///
299/// This uses linear programming to find the tightest possible bounds
300/// given constraints. For now, uses a simple iterative tightening approach.
301pub fn compute_tight_bounds(
302    expr: &TLExpr,
303    prob_map: &HashMap<String, ProbabilityInterval>,
304) -> ProbabilityInterval {
305    // Start with Fréchet bounds
306    let mut current = propagate_probabilities(expr, prob_map);
307
308    // Iteratively tighten bounds by considering dependencies
309    // For simplicity, we do 3 iterations (could be made configurable)
310    for _ in 0..3 {
311        current = tighten_iteration(expr, prob_map, &current);
312    }
313
314    current
315}
316
317fn tighten_iteration(
318    expr: &TLExpr,
319    prob_map: &HashMap<String, ProbabilityInterval>,
320    current: &ProbabilityInterval,
321) -> ProbabilityInterval {
322    match expr {
323        TLExpr::And(left, right) => {
324            let left_prob = compute_tight_bounds(left, prob_map);
325            let right_prob = compute_tight_bounds(right, prob_map);
326
327            // Tighten using independence assumption if possible
328            let mut result = left_prob.and(&right_prob);
329
330            // Additional tightening: if we know the result bounds, constrain components
331            if let Some(intersection) = result.intersect(current) {
332                result = intersection;
333            }
334
335            result
336        }
337
338        TLExpr::Or(left, right) => {
339            let left_prob = compute_tight_bounds(left, prob_map);
340            let right_prob = compute_tight_bounds(right, prob_map);
341
342            let mut result = left_prob.or(&right_prob);
343
344            if let Some(intersection) = result.intersect(current) {
345                result = intersection;
346            }
347
348            result
349        }
350
351        _ => propagate_probabilities(expr, prob_map),
352    }
353}
354
355/// Extract probabilistic semantics from weighted rules.
356///
357/// Converts weighted rules into probability distributions over possible worlds.
358pub fn extract_probabilistic_semantics(expr: &TLExpr) -> Vec<(f64, TLExpr)> {
359    let mut weighted_rules = Vec::new();
360    extract_weighted_rec(expr, &mut weighted_rules);
361    weighted_rules
362}
363
364fn extract_weighted_rec(expr: &TLExpr, result: &mut Vec<(f64, TLExpr)>) {
365    match expr {
366        TLExpr::WeightedRule { weight, rule } => {
367            result.push((*weight, (**rule).clone()));
368            extract_weighted_rec(rule, result);
369        }
370
371        TLExpr::ProbabilisticChoice { alternatives } => {
372            for (prob, expr) in alternatives {
373                result.push((*prob, expr.clone()));
374                extract_weighted_rec(expr, result);
375            }
376        }
377
378        TLExpr::And(l, r) | TLExpr::Or(l, r) | TLExpr::Imply(l, r) => {
379            extract_weighted_rec(l, result);
380            extract_weighted_rec(r, result);
381        }
382
383        TLExpr::Not(e) => extract_weighted_rec(e, result),
384
385        _ => {}
386    }
387}
388
389/// Compute probability of an expression under a Markov Logic Network (MLN) semantics.
390///
391/// MLN uses weighted rules where weight w corresponds to log-odds ratio.
392/// P(world) ∝ exp(∑ w_i * n_i) where n_i is number of groundings satisfied.
393pub fn mln_probability(
394    _expr: &TLExpr,
395    weights: &[(f64, TLExpr)],
396    evidence: &HashMap<String, bool>,
397) -> f64 {
398    // Simplified MLN: compute unnormalized probability
399    let mut total_weight = 0.0;
400
401    for (weight, rule) in weights {
402        if evaluates_true(rule, evidence) {
403            total_weight += weight;
404        }
405    }
406
407    // Logistic function to get probability
408    1.0 / (1.0 + (-total_weight).exp())
409}
410
411/// Simple boolean evaluation for ground facts.
412fn evaluates_true(expr: &TLExpr, evidence: &HashMap<String, bool>) -> bool {
413    match expr {
414        TLExpr::Pred { name, .. } => evidence.get(name).copied().unwrap_or(false),
415
416        TLExpr::And(l, r) => evaluates_true(l, evidence) && evaluates_true(r, evidence),
417
418        TLExpr::Or(l, r) => evaluates_true(l, evidence) || evaluates_true(r, evidence),
419
420        TLExpr::Not(e) => !evaluates_true(e, evidence),
421
422        TLExpr::Imply(l, r) => !evaluates_true(l, evidence) || evaluates_true(r, evidence),
423
424        TLExpr::Constant(v) => *v >= 1.0,
425
426        _ => false,
427    }
428}
429
#[cfg(test)]
mod tests {
    //! Unit tests for interval arithmetic, credal sets, propagation,
    //! and the simplified MLN semantics.
    use super::*;

    #[test]
    fn test_probability_interval_creation() {
        let interval = ProbabilityInterval::new(0.3, 0.7).unwrap();
        assert!((interval.lower - 0.3).abs() < 1e-10);
        assert!((interval.upper - 0.7).abs() < 1e-10);
        assert!((interval.width() - 0.4).abs() < 1e-10);

        // Invalid intervals: below 0, inverted bounds, above 1.
        assert!(ProbabilityInterval::new(-0.1, 0.5).is_none());
        assert!(ProbabilityInterval::new(0.8, 0.5).is_none());
        assert!(ProbabilityInterval::new(0.5, 1.5).is_none());
    }

    #[test]
    fn test_precise_probability() {
        let precise = ProbabilityInterval::precise(0.5).unwrap();
        assert!(precise.is_precise());
        assert_eq!(precise.width(), 0.0);
    }

    #[test]
    fn test_vacuous_interval() {
        let vacuous = ProbabilityInterval::vacuous();
        assert!(vacuous.is_vacuous());
        assert_eq!(vacuous.width(), 1.0);
    }

    #[test]
    fn test_complement() {
        // [0.3, 0.7] is symmetric around 0.5, so its complement is itself.
        let interval = ProbabilityInterval::new(0.3, 0.7).unwrap();
        let complement = interval.complement();
        assert!((complement.lower - 0.3).abs() < 1e-10);
        assert!((complement.upper - 0.7).abs() < 1e-10);
    }

    #[test]
    fn test_frechet_and() {
        let p_a = ProbabilityInterval::new(0.4, 0.6).unwrap();
        let p_b = ProbabilityInterval::new(0.5, 0.8).unwrap();
        let p_and = p_a.and(&p_b);

        // Lower: max(0, 0.4 + 0.5 - 1) = 0.0
        assert_eq!(p_and.lower, 0.0);
        // Upper: min(0.6, 0.8) = 0.6
        assert_eq!(p_and.upper, 0.6);
    }

    #[test]
    fn test_frechet_or() {
        let p_a = ProbabilityInterval::new(0.4, 0.6).unwrap();
        let p_b = ProbabilityInterval::new(0.5, 0.8).unwrap();
        let p_or = p_a.or(&p_b);

        // Lower: max(0.4, 0.5) = 0.5
        assert_eq!(p_or.lower, 0.5);
        // Upper: min(1, 0.6 + 0.8) = 1.0
        assert_eq!(p_or.upper, 1.0);
    }

    #[test]
    fn test_implication_bounds() {
        let p_a = ProbabilityInterval::new(0.3, 0.5).unwrap();
        let p_b = ProbabilityInterval::new(0.6, 0.9).unwrap();
        let p_implies = p_a.implies(&p_b);

        // A -> B ≡ ¬A ∨ B: implies() must match composing the primitives.
        let not_a = p_a.complement();
        let expected = not_a.or(&p_b);

        assert_eq!(p_implies.lower, expected.lower);
        assert_eq!(p_implies.upper, expected.upper);
    }

    #[test]
    fn test_conditional_probability() {
        let p_a = ProbabilityInterval::new(0.4, 0.6).unwrap();
        let p_a_and_b = ProbabilityInterval::new(0.2, 0.3).unwrap();

        let p_b_given_a = p_a.conditional(&p_a_and_b).unwrap();

        // P(B|A) = P(A ∧ B) / P(A)
        // Lower: 0.2 / 0.6 = 0.333...
        // Upper: 0.3 / 0.4 = 0.75
        assert!((p_b_given_a.lower - 0.333).abs() < 0.01);
        assert!((p_b_given_a.upper - 0.75).abs() < 1e-10);
    }

    #[test]
    fn test_interval_intersection() {
        let i1 = ProbabilityInterval::new(0.2, 0.7).unwrap();
        let i2 = ProbabilityInterval::new(0.5, 0.9).unwrap();

        let intersection = i1.intersect(&i2).unwrap();
        assert_eq!(intersection.lower, 0.5);
        assert_eq!(intersection.upper, 0.7);

        // Disjoint intervals: no intersection.
        let i3 = ProbabilityInterval::new(0.1, 0.3).unwrap();
        let i4 = ProbabilityInterval::new(0.6, 0.9).unwrap();
        assert!(i3.intersect(&i4).is_none());
    }

    #[test]
    fn test_convex_combination() {
        let i1 = ProbabilityInterval::new(0.2, 0.4).unwrap();
        let i2 = ProbabilityInterval::new(0.6, 0.8).unwrap();

        let combo = i1.convex_combine(&i2, 0.5).unwrap();
        assert!((combo.lower - 0.4).abs() < 1e-10); // 0.2 * 0.5 + 0.6 * 0.5
        assert!((combo.upper - 0.6).abs() < 1e-10); // 0.4 * 0.5 + 0.8 * 0.5
    }

    #[test]
    fn test_propagate_probabilities_and() {
        let mut prob_map = HashMap::new();
        prob_map.insert("P".to_string(), ProbabilityInterval::new(0.4, 0.6).unwrap());
        prob_map.insert("Q".to_string(), ProbabilityInterval::new(0.5, 0.8).unwrap());

        let expr = TLExpr::and(TLExpr::pred("P", vec![]), TLExpr::pred("Q", vec![]));

        // Same Fréchet bounds as test_frechet_and, via the expression walker.
        let result = propagate_probabilities(&expr, &prob_map);
        assert_eq!(result.lower, 0.0);
        assert_eq!(result.upper, 0.6);
    }

    #[test]
    fn test_propagate_probabilities_or() {
        let mut prob_map = HashMap::new();
        prob_map.insert("P".to_string(), ProbabilityInterval::new(0.4, 0.6).unwrap());
        prob_map.insert("Q".to_string(), ProbabilityInterval::new(0.5, 0.8).unwrap());

        let expr = TLExpr::or(TLExpr::pred("P", vec![]), TLExpr::pred("Q", vec![]));

        let result = propagate_probabilities(&expr, &prob_map);
        assert_eq!(result.lower, 0.5);
        assert_eq!(result.upper, 1.0);
    }

    #[test]
    fn test_propagate_probabilities_not() {
        let mut prob_map = HashMap::new();
        prob_map.insert("P".to_string(), ProbabilityInterval::new(0.3, 0.7).unwrap());

        let expr = TLExpr::negate(TLExpr::pred("P", vec![]));

        // Complement of the symmetric interval [0.3, 0.7] is itself.
        let result = propagate_probabilities(&expr, &prob_map);
        assert!((result.lower - 0.3).abs() < 1e-10);
        assert!((result.upper - 0.7).abs() < 1e-10);
    }

    #[test]
    fn test_weighted_rule_propagation() {
        let mut prob_map = HashMap::new();
        prob_map.insert("P".to_string(), ProbabilityInterval::new(0.5, 0.8).unwrap());

        let expr = TLExpr::weighted_rule(0.5, TLExpr::pred("P", vec![]));

        // Weight scales both bounds of the rule's interval.
        let result = propagate_probabilities(&expr, &prob_map);
        assert_eq!(result.lower, 0.25); // 0.5 * 0.5
        assert_eq!(result.upper, 0.4); // 0.5 * 0.8
    }

    #[test]
    fn test_probabilistic_choice() {
        let mut prob_map = HashMap::new();
        prob_map.insert("P".to_string(), ProbabilityInterval::precise(0.6).unwrap());
        prob_map.insert("Q".to_string(), ProbabilityInterval::precise(0.4).unwrap());

        let expr = TLExpr::probabilistic_choice(vec![
            (0.5, TLExpr::pred("P", vec![])),
            (0.5, TLExpr::pred("Q", vec![])),
        ]);

        let result = propagate_probabilities(&expr, &prob_map);
        // Expected mixture: 0.5 * 0.6 + 0.5 * 0.4 = 0.5 for both bounds.
        assert_eq!(result.lower, 0.5);
        assert_eq!(result.upper, 0.5);
    }

    #[test]
    fn test_credal_set() {
        let mut dist1 = HashMap::new();
        dist1.insert("A".to_string(), 0.3);
        dist1.insert("B".to_string(), 0.7);

        let mut dist2 = HashMap::new();
        dist2.insert("A".to_string(), 0.6);
        dist2.insert("B".to_string(), 0.4);

        let credal = CredalSet::new(vec![dist1, dist2]);

        // Bounds are the min/max of P(A) across the two extreme points.
        assert_eq!(credal.lower_prob("A"), 0.3);
        assert_eq!(credal.upper_prob("A"), 0.6);
        assert!(!credal.is_precise());
    }

    #[test]
    fn test_mln_probability() {
        let rule = TLExpr::pred("P", vec![]);
        let weights = vec![(2.0, rule.clone())];

        let mut evidence = HashMap::new();
        evidence.insert("P".to_string(), true);

        let prob = mln_probability(&rule, &weights, &evidence);
        // Satisfied weight 2.0 → logistic(2) = exp(2) / (1 + exp(2)) ≈ 0.88
        assert!((prob - 0.88).abs() < 0.01);
    }
}