Skip to main content

scirs2_stats/bayesian_network/
cpd.rs

1//! Conditional Probability Distributions (CPDs) for Bayesian Networks.
2//!
3//! Provides:
4//! - [`CPD`] trait — common interface for all CPDs
5//! - [`TabularCPD`] — discrete CPD stored as a table
6//! - [`GaussianCPD`] — linear Gaussian CPD
7//! - [`MixtureCPD`] — mixture of TabularCPDs
8//! - [`ConditionalLinear`] — conditional linear Gaussian for continuous parents
9
10use crate::StatsError;
11use std::f64::consts::PI;
12
13// ---------------------------------------------------------------------------
14// Trait
15// ---------------------------------------------------------------------------
16
/// Common interface shared by every conditional probability distribution.
pub trait CPD: Send + Sync {
    /// Index of the node this CPD is attached to.
    fn node(&self) -> usize;

    /// Evaluate the probability (or density) of `value` given `parent_values`.
    ///
    /// Discrete CPDs return P(X = value | pa(X) = parent_values).
    /// Continuous CPDs return a probability density instead.
    fn prob(&self, value: usize, parent_values: &[usize]) -> f64;

    /// Number of states of a discrete node; continuous CPDs report 0.
    fn cardinality(&self) -> usize;

    /// Indices of the parent nodes this CPD conditions on.
    fn parent_indices(&self) -> &[usize];

    /// Whether this CPD models a continuous variable (defaults to `false`).
    fn is_continuous(&self) -> bool {
        false
    }

    /// Natural logarithm of [`CPD::prob`].
    ///
    /// Zero or negative probabilities map to negative infinity; any other
    /// value (including NaN) is passed straight to `ln`.
    fn log_prob(&self, value: usize, parent_values: &[usize]) -> f64 {
        match self.prob(value, parent_values) {
            p if p <= 0.0 => f64::NEG_INFINITY,
            p => p.ln(),
        }
    }
}
50
51// ---------------------------------------------------------------------------
52// TabularCPD
53// ---------------------------------------------------------------------------
54
55/// Discrete CPD stored as a conditional probability table (CPT).
56///
57/// The table is indexed by the combined parent configuration.
58/// `values[row][val]` = P(X = val | parent_config = row).
59///
60/// Row indexing follows *column-major* (rightmost parent cycles fastest),
61/// matching pgmpy convention:
62///   `row = Sum_i (parent_values[i] * stride[i])`
63/// where `stride[i]` = product of cardinalities of parents to the right.
64#[derive(Debug, Clone)]
65pub struct TabularCPD {
66    /// Index of the node this CPD belongs to.
67    pub node_idx: usize,
68    /// Number of values (cardinality) of this node.
69    pub n_values: usize,
70    /// Cardinalities of parent nodes.
71    pub parent_card: Vec<usize>,
72    /// Parent node indices.
73    pub parent_indices: Vec<usize>,
74    /// Table: `table[row]` = probability distribution over `n_values` states.
75    /// Length = product(parent_card), each inner vec has length n_values.
76    pub table: Vec<Vec<f64>>,
77    /// Strides for row computation.
78    strides: Vec<usize>,
79}
80
81impl TabularCPD {
82    /// Create a new TabularCPD.
83    ///
84    /// # Arguments
85    /// - `node_idx`: Index of the node.
86    /// - `n_values`: Cardinality of this node.
87    /// - `parent_indices`: Indices of parent nodes.
88    /// - `parent_card`: Cardinalities of each parent (same order as parent_indices).
89    /// - `values`: Probability table. Each row is a probability distribution.
90    ///   If there are no parents, `values` should have exactly one row.
91    pub fn new(
92        node_idx: usize,
93        n_values: usize,
94        parent_indices: Vec<usize>,
95        parent_card: Vec<usize>,
96        values: Vec<Vec<f64>>,
97    ) -> Result<Self, StatsError> {
98        if parent_indices.len() != parent_card.len() {
99            return Err(StatsError::InvalidInput(
100                "parent_indices and parent_card must have the same length".to_string(),
101            ));
102        }
103        let n_rows: usize = if parent_card.is_empty() {
104            1
105        } else {
106            parent_card.iter().product()
107        };
108        if values.len() != n_rows {
109            return Err(StatsError::InvalidInput(format!(
110                "Expected {n_rows} rows (product of parent cardinalities), got {}",
111                values.len()
112            )));
113        }
114        for (i, row) in values.iter().enumerate() {
115            if row.len() != n_values {
116                return Err(StatsError::InvalidInput(format!(
117                    "Row {i} has {} values, expected {n_values}",
118                    row.len()
119                )));
120            }
121            let sum: f64 = row.iter().sum();
122            if (sum - 1.0).abs() > 1e-6 {
123                return Err(StatsError::InvalidInput(format!(
124                    "Row {i} does not sum to 1.0 (sum={sum:.6})"
125                )));
126            }
127        }
128        // Compute strides: stride[i] = product(parent_card[i+1..])
129        let strides = compute_strides(&parent_card);
130        Ok(Self {
131            node_idx,
132            n_values,
133            parent_card,
134            parent_indices,
135            table: values,
136            strides,
137        })
138    }
139
140    /// Compute the row index for a given parent configuration.
141    pub fn row_index(&self, parent_values: &[usize]) -> Result<usize, StatsError> {
142        if parent_values.len() != self.parent_card.len() {
143            return Err(StatsError::InvalidInput(format!(
144                "Expected {} parent values, got {}",
145                self.parent_card.len(),
146                parent_values.len()
147            )));
148        }
149        let mut row = 0usize;
150        for (i, &pv) in parent_values.iter().enumerate() {
151            if pv >= self.parent_card[i] {
152                return Err(StatsError::InvalidInput(format!(
153                    "Parent {i} value {pv} out of range (card={})",
154                    self.parent_card[i]
155                )));
156            }
157            row += pv * self.strides[i];
158        }
159        Ok(row)
160    }
161
162    /// Return the full conditional distribution P(X | parent_values).
163    pub fn distribution(&self, parent_values: &[usize]) -> Result<&[f64], StatsError> {
164        let row = self.row_index(parent_values)?;
165        Ok(&self.table[row])
166    }
167}
168
169impl CPD for TabularCPD {
170    fn node(&self) -> usize {
171        self.node_idx
172    }
173
174    fn prob(&self, value: usize, parent_values: &[usize]) -> f64 {
175        if value >= self.n_values {
176            return 0.0;
177        }
178        let row = match self.row_index(parent_values) {
179            Ok(r) => r,
180            Err(_) => return 0.0,
181        };
182        self.table[row][value]
183    }
184
185    fn cardinality(&self) -> usize {
186        self.n_values
187    }
188
189    fn parent_indices(&self) -> &[usize] {
190        &self.parent_indices
191    }
192}
193
194// ---------------------------------------------------------------------------
195// GaussianCPD
196// ---------------------------------------------------------------------------
197
198/// Linear Gaussian CPD: X | pa(X) ~ N(mu + beta^T * pa(X), sigma^2).
199///
200/// For a root node (no parents), this is simply N(mu, sigma^2).
201#[derive(Debug, Clone)]
202pub struct GaussianCPD {
203    /// Index of this node.
204    pub node_idx: usize,
205    /// Intercept (mean when all parents are 0).
206    pub mu: f64,
207    /// Noise standard deviation.
208    pub sigma: f64,
209    /// Regression coefficients for each parent.
210    pub beta: Vec<f64>,
211    /// Parent node indices.
212    pub parent_indices: Vec<usize>,
213}
214
215impl GaussianCPD {
216    /// Create a new GaussianCPD.
217    pub fn new(
218        node_idx: usize,
219        mu: f64,
220        sigma: f64,
221        beta: Vec<f64>,
222        parent_indices: Vec<usize>,
223    ) -> Result<Self, StatsError> {
224        if sigma <= 0.0 {
225            return Err(StatsError::InvalidInput(format!(
226                "sigma must be positive, got {sigma}"
227            )));
228        }
229        if beta.len() != parent_indices.len() {
230            return Err(StatsError::InvalidInput(
231                "beta and parent_indices must have the same length".to_string(),
232            ));
233        }
234        Ok(Self {
235            node_idx,
236            mu,
237            sigma,
238            beta,
239            parent_indices,
240        })
241    }
242
243    /// Compute the conditional mean given parent values (as continuous f64).
244    pub fn conditional_mean(&self, parent_vals: &[f64]) -> f64 {
245        self.mu
246            + self
247                .beta
248                .iter()
249                .zip(parent_vals)
250                .map(|(b, v)| b * v)
251                .sum::<f64>()
252    }
253
254    /// Compute the conditional density p(x | pa(X)) given continuous value x.
255    pub fn density(&self, x: f64, parent_vals: &[f64]) -> f64 {
256        let mean = self.conditional_mean(parent_vals);
257        let z = (x - mean) / self.sigma;
258        (-0.5 * z * z).exp() / (self.sigma * (2.0 * PI).sqrt())
259    }
260}
261
262impl CPD for GaussianCPD {
263    fn node(&self) -> usize {
264        self.node_idx
265    }
266
267    /// Returns density evaluated at `value` (cast to f64) with `parent_values` cast to f64.
268    fn prob(&self, value: usize, parent_values: &[usize]) -> f64 {
269        let pv: Vec<f64> = parent_values.iter().map(|&v| v as f64).collect();
270        self.density(value as f64, &pv)
271    }
272
273    fn cardinality(&self) -> usize {
274        0 // continuous
275    }
276
277    fn parent_indices(&self) -> &[usize] {
278        &self.parent_indices
279    }
280
281    fn is_continuous(&self) -> bool {
282        true
283    }
284}
285
286// ---------------------------------------------------------------------------
287// MixtureCPD
288// ---------------------------------------------------------------------------
289
290/// Mixture of TabularCPDs.
291///
292/// P(X | pa(X)) = Σ_k w_k * P_k(X | pa(X))
293#[derive(Debug, Clone)]
294pub struct MixtureCPD {
295    /// Index of this node.
296    pub node_idx: usize,
297    /// Component CPDs.
298    pub components: Vec<TabularCPD>,
299    /// Mixture weights (must sum to 1).
300    pub weights: Vec<f64>,
301}
302
303impl MixtureCPD {
304    /// Create a new MixtureCPD.
305    pub fn new(
306        node_idx: usize,
307        components: Vec<TabularCPD>,
308        weights: Vec<f64>,
309    ) -> Result<Self, StatsError> {
310        if components.is_empty() {
311            return Err(StatsError::InvalidInput(
312                "MixtureCPD needs at least one component".to_string(),
313            ));
314        }
315        if components.len() != weights.len() {
316            return Err(StatsError::InvalidInput(
317                "components and weights must have the same length".to_string(),
318            ));
319        }
320        let wsum: f64 = weights.iter().sum();
321        if (wsum - 1.0).abs() > 1e-6 {
322            return Err(StatsError::InvalidInput(format!(
323                "weights must sum to 1.0 (got {wsum:.6})"
324            )));
325        }
326        for w in &weights {
327            if *w < 0.0 {
328                return Err(StatsError::InvalidInput(
329                    "weights must be non-negative".to_string(),
330                ));
331            }
332        }
333        Ok(Self {
334            node_idx,
335            components,
336            weights,
337        })
338    }
339}
340
341impl CPD for MixtureCPD {
342    fn node(&self) -> usize {
343        self.node_idx
344    }
345
346    fn prob(&self, value: usize, parent_values: &[usize]) -> f64 {
347        self.components
348            .iter()
349            .zip(&self.weights)
350            .map(|(c, w)| w * c.prob(value, parent_values))
351            .sum()
352    }
353
354    fn cardinality(&self) -> usize {
355        self.components[0].cardinality()
356    }
357
358    fn parent_indices(&self) -> &[usize] {
359        self.components[0].parent_indices()
360    }
361}
362
363// ---------------------------------------------------------------------------
364// ConditionalLinear
365// ---------------------------------------------------------------------------
366
367/// Conditional Linear Gaussian for continuous parents and discrete output.
368///
369/// `P(X=k | pa(X)) = softmax(W[k] * pa(X) + b[k])`
370/// `sigma[k]` stores standard deviations (for density evaluation).
371#[derive(Debug, Clone)]
372pub struct ConditionalLinear {
373    /// Index of this node.
374    pub node_idx: usize,
375    /// Weight matrix: `W[k]` has length = number of parents.
376    pub w: Vec<Vec<f64>>,
377    /// Bias vector: `b[k]` for each class k.
378    pub b: Vec<f64>,
379    /// Standard deviations (used if output is also continuous).
380    pub sigma: Vec<f64>,
381    /// Number of output classes.
382    pub n_classes: usize,
383    /// Parent node indices.
384    pub parent_indices: Vec<usize>,
385}
386
387impl ConditionalLinear {
388    /// Create a new ConditionalLinear CPD.
389    pub fn new(
390        node_idx: usize,
391        w: Vec<Vec<f64>>,
392        b: Vec<f64>,
393        sigma: Vec<f64>,
394        n_classes: usize,
395        parent_indices: Vec<usize>,
396    ) -> Result<Self, StatsError> {
397        if w.len() != n_classes || b.len() != n_classes || sigma.len() != n_classes {
398            return Err(StatsError::InvalidInput(
399                "w, b, sigma must all have length n_classes".to_string(),
400            ));
401        }
402        Ok(Self {
403            node_idx,
404            w,
405            b,
406            sigma,
407            n_classes,
408            parent_indices,
409        })
410    }
411
412    /// Compute softmax probabilities.
413    pub fn softmax(&self, parent_values: &[f64]) -> Vec<f64> {
414        let logits: Vec<f64> = self
415            .w
416            .iter()
417            .zip(&self.b)
418            .map(|(wk, bk)| {
419                bk + wk
420                    .iter()
421                    .zip(parent_values)
422                    .map(|(wi, xi)| wi * xi)
423                    .sum::<f64>()
424            })
425            .collect();
426        let max_l = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
427        let exps: Vec<f64> = logits.iter().map(|l| (l - max_l).exp()).collect();
428        let sum: f64 = exps.iter().sum();
429        exps.iter().map(|e| e / sum).collect()
430    }
431}
432
433impl CPD for ConditionalLinear {
434    fn node(&self) -> usize {
435        self.node_idx
436    }
437
438    fn prob(&self, value: usize, parent_values: &[usize]) -> f64 {
439        if value >= self.n_classes {
440            return 0.0;
441        }
442        let pv: Vec<f64> = parent_values.iter().map(|&v| v as f64).collect();
443        let probs = self.softmax(&pv);
444        probs[value]
445    }
446
447    fn cardinality(&self) -> usize {
448        self.n_classes
449    }
450
451    fn parent_indices(&self) -> &[usize] {
452        &self.parent_indices
453    }
454}
455
456// ---------------------------------------------------------------------------
457// Helpers
458// ---------------------------------------------------------------------------
459
/// Compute the stride of each dimension so that the last one varies fastest.
///
/// `stride[i] = product(card[i+1..])`; the final stride is always 1 and an
/// empty `card` yields an empty vector.
pub(crate) fn compute_strides(card: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; card.len()];
    let mut running = 1usize;
    // Walk right-to-left, skipping the last slot (it stays 1). Each remaining
    // slot receives the product of all cardinalities to its right.
    for (slot, &c) in strides.iter_mut().rev().skip(1).zip(card.iter().rev()) {
        running *= c;
        *slot = running;
    }
    strides
}
470
471// ---------------------------------------------------------------------------
472// Unit tests
473// ---------------------------------------------------------------------------
474
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by every floating-point comparison below.
    const TOL: f64 = 1e-9;

    /// True when `actual` lies within `TOL` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    /// Root node: P(Rain=0) = 0.8, P(Rain=1) = 0.2.
    fn rain_cpd() -> TabularCPD {
        let table = vec![vec![0.8, 0.2]];
        TabularCPD::new(0, 2, vec![], vec![], table).unwrap()
    }

    /// P(WG | Rain, Sprinkler): one row per parent configuration (4 rows).
    fn wetgrass_cpd() -> TabularCPD {
        let parents = vec![0, 1]; // Rain, Sprinkler
        let cards = vec![2, 2];
        let table = vec![
            vec![0.99, 0.01], // R=0, S=0
            vec![0.01, 0.99], // R=0, S=1
            vec![0.01, 0.99], // R=1, S=0
            vec![0.01, 0.99], // R=1, S=1
        ];
        TabularCPD::new(2, 2, parents, cards, table).unwrap()
    }

    #[test]
    fn test_tabular_no_parents() {
        let cpd = rain_cpd();
        assert!(close(cpd.prob(0, &[]), 0.8));
        assert!(close(cpd.prob(1, &[]), 0.2));
    }

    #[test]
    fn test_tabular_with_parents() {
        let cpd = wetgrass_cpd();
        // P(WG=1 | Rain=1, Spr=0) = 0.99
        assert!(close(cpd.prob(1, &[1, 0]), 0.99));
        // P(WG=0 | Rain=0, Spr=0) = 0.99
        assert!(close(cpd.prob(0, &[0, 0]), 0.99));
    }

    #[test]
    fn test_tabular_bad_sum() {
        // The single row sums to 0.8, so construction must fail.
        let res = TabularCPD::new(0, 2, vec![], vec![], vec![vec![0.5, 0.3]]);
        assert!(res.is_err());
    }

    #[test]
    fn test_gaussian_cpd() {
        let cpd = GaussianCPD::new(0, 0.0, 1.0, vec![0.5], vec![1]).unwrap();
        // Mean = 0 + 0.5 * 2.0 = 1.0, so x = 1.0 sits at the peak of N(1, 1).
        let d = cpd.density(1.0, &[2.0]);
        let expected = 1.0 / (2.0 * PI).sqrt();
        assert!(close(d, expected));
    }

    #[test]
    fn test_mixture_cpd() {
        let c1 = TabularCPD::new(0, 2, vec![], vec![], vec![vec![0.6, 0.4]]).unwrap();
        let c2 = TabularCPD::new(0, 2, vec![], vec![], vec![vec![0.4, 0.6]]).unwrap();
        let mix = MixtureCPD::new(0, vec![c1, c2], vec![0.5, 0.5]).unwrap();
        // Expected: 0.5*0.6 + 0.5*0.4 = 0.5
        assert!(close(mix.prob(0, &[]), 0.5));
    }

    #[test]
    fn test_conditional_linear() {
        // Two classes, one parent.
        let weights = vec![vec![1.0], vec![-1.0]];
        let biases = vec![0.0, 0.0];
        let sigmas = vec![1.0, 1.0];
        let cpd = ConditionalLinear::new(0, weights, biases, sigmas, 2, vec![1]).unwrap();
        // parent_val = 0 → logits = [0, 0] → softmax = [0.5, 0.5]
        assert!(close(cpd.prob(0, &[0]), 0.5));
    }

    #[test]
    fn test_strides() {
        assert_eq!(compute_strides(&[2, 3]), vec![3, 1]);
        assert_eq!(compute_strides(&[2, 3, 4]), vec![12, 4, 1]);
        assert_eq!(compute_strides(&[]), Vec::<usize>::new());
    }
}