1use crate::StatsError;
11use std::f64::consts::PI;
12
/// Interface shared by every conditional probability distribution (CPD) in
/// this module.
///
/// A CPD belongs to a single node and is evaluated for one value of that node
/// together with one value per parent. Implementors must be `Send + Sync`.
pub trait CPD: Send + Sync {
    /// Returns the index of the node this distribution is attached to.
    fn node(&self) -> usize;

    /// Returns the probability of `value` given `parent_values`.
    ///
    /// Continuous implementations return a density instead; see
    /// [`CPD::is_continuous`].
    fn prob(&self, value: usize, parent_values: &[usize]) -> f64;

    /// Returns the number of discrete values the node can take.
    fn cardinality(&self) -> usize;

    /// Returns the indices of the node's parents.
    fn parent_indices(&self) -> &[usize];

    /// Reports whether the distribution is continuous.
    ///
    /// The default answer is `false`; continuous implementations override it.
    fn is_continuous(&self) -> bool {
        false
    }

    /// Returns the natural logarithm of [`CPD::prob`].
    ///
    /// A probability that is zero or negative maps to negative infinity
    /// rather than being passed to `ln`.
    fn log_prob(&self, value: usize, parent_values: &[usize]) -> f64 {
        match self.prob(value, parent_values) {
            p if p <= 0.0 => f64::NEG_INFINITY,
            p => p.ln(),
        }
    }
}
50
/// Discrete conditional probability table for a single node.
///
/// Instances are normally built with `TabularCPD::new`, which validates the
/// table shape and fills in the private `strides` field.
#[derive(Clone, Debug)]
pub struct TabularCPD {
    /// Index of the node this table describes.
    pub node_idx: usize,
    /// Number of values the node can take; every row of `table` has this length.
    pub n_values: usize,
    /// Cardinality of each parent, aligned position-by-position with
    /// `parent_indices`.
    pub parent_card: Vec<usize>,
    /// Indices of the parent nodes.
    pub parent_indices: Vec<usize>,
    /// One row per joint parent configuration; each row lists the probability
    /// of every value of the node.
    pub table: Vec<Vec<f64>>,
    /// Per-parent multipliers produced by `compute_strides`; `row_index`
    /// combines them with the parent values to locate a row of `table`.
    strides: Vec<usize>,
}
80
81impl TabularCPD {
82 pub fn new(
92 node_idx: usize,
93 n_values: usize,
94 parent_indices: Vec<usize>,
95 parent_card: Vec<usize>,
96 values: Vec<Vec<f64>>,
97 ) -> Result<Self, StatsError> {
98 if parent_indices.len() != parent_card.len() {
99 return Err(StatsError::InvalidInput(
100 "parent_indices and parent_card must have the same length".to_string(),
101 ));
102 }
103 let n_rows: usize = if parent_card.is_empty() {
104 1
105 } else {
106 parent_card.iter().product()
107 };
108 if values.len() != n_rows {
109 return Err(StatsError::InvalidInput(format!(
110 "Expected {n_rows} rows (product of parent cardinalities), got {}",
111 values.len()
112 )));
113 }
114 for (i, row) in values.iter().enumerate() {
115 if row.len() != n_values {
116 return Err(StatsError::InvalidInput(format!(
117 "Row {i} has {} values, expected {n_values}",
118 row.len()
119 )));
120 }
121 let sum: f64 = row.iter().sum();
122 if (sum - 1.0).abs() > 1e-6 {
123 return Err(StatsError::InvalidInput(format!(
124 "Row {i} does not sum to 1.0 (sum={sum:.6})"
125 )));
126 }
127 }
128 let strides = compute_strides(&parent_card);
130 Ok(Self {
131 node_idx,
132 n_values,
133 parent_card,
134 parent_indices,
135 table: values,
136 strides,
137 })
138 }
139
140 pub fn row_index(&self, parent_values: &[usize]) -> Result<usize, StatsError> {
142 if parent_values.len() != self.parent_card.len() {
143 return Err(StatsError::InvalidInput(format!(
144 "Expected {} parent values, got {}",
145 self.parent_card.len(),
146 parent_values.len()
147 )));
148 }
149 let mut row = 0usize;
150 for (i, &pv) in parent_values.iter().enumerate() {
151 if pv >= self.parent_card[i] {
152 return Err(StatsError::InvalidInput(format!(
153 "Parent {i} value {pv} out of range (card={})",
154 self.parent_card[i]
155 )));
156 }
157 row += pv * self.strides[i];
158 }
159 Ok(row)
160 }
161
162 pub fn distribution(&self, parent_values: &[usize]) -> Result<&[f64], StatsError> {
164 let row = self.row_index(parent_values)?;
165 Ok(&self.table[row])
166 }
167}
168
169impl CPD for TabularCPD {
170 fn node(&self) -> usize {
171 self.node_idx
172 }
173
174 fn prob(&self, value: usize, parent_values: &[usize]) -> f64 {
175 if value >= self.n_values {
176 return 0.0;
177 }
178 let row = match self.row_index(parent_values) {
179 Ok(r) => r,
180 Err(_) => return 0.0,
181 };
182 self.table[row][value]
183 }
184
185 fn cardinality(&self) -> usize {
186 self.n_values
187 }
188
189 fn parent_indices(&self) -> &[usize] {
190 &self.parent_indices
191 }
192}
193
/// Gaussian distribution whose mean depends linearly on the parent values.
///
/// The conditional mean is `mu + sum(beta[i] * parent[i])` and the standard
/// deviation is `sigma` (see `conditional_mean` and `density`).
#[derive(Clone, Debug)]
pub struct GaussianCPD {
    /// Index of the node this distribution describes.
    pub node_idx: usize,
    /// Intercept of the conditional mean.
    pub mu: f64,
    /// Standard deviation; `GaussianCPD::new` requires it to be positive.
    pub sigma: f64,
    /// Linear coefficients, one per parent.
    pub beta: Vec<f64>,
    /// Indices of the parent nodes.
    pub parent_indices: Vec<usize>,
}
214
215impl GaussianCPD {
216 pub fn new(
218 node_idx: usize,
219 mu: f64,
220 sigma: f64,
221 beta: Vec<f64>,
222 parent_indices: Vec<usize>,
223 ) -> Result<Self, StatsError> {
224 if sigma <= 0.0 {
225 return Err(StatsError::InvalidInput(format!(
226 "sigma must be positive, got {sigma}"
227 )));
228 }
229 if beta.len() != parent_indices.len() {
230 return Err(StatsError::InvalidInput(
231 "beta and parent_indices must have the same length".to_string(),
232 ));
233 }
234 Ok(Self {
235 node_idx,
236 mu,
237 sigma,
238 beta,
239 parent_indices,
240 })
241 }
242
243 pub fn conditional_mean(&self, parent_vals: &[f64]) -> f64 {
245 self.mu
246 + self
247 .beta
248 .iter()
249 .zip(parent_vals)
250 .map(|(b, v)| b * v)
251 .sum::<f64>()
252 }
253
254 pub fn density(&self, x: f64, parent_vals: &[f64]) -> f64 {
256 let mean = self.conditional_mean(parent_vals);
257 let z = (x - mean) / self.sigma;
258 (-0.5 * z * z).exp() / (self.sigma * (2.0 * PI).sqrt())
259 }
260}
261
262impl CPD for GaussianCPD {
263 fn node(&self) -> usize {
264 self.node_idx
265 }
266
267 fn prob(&self, value: usize, parent_values: &[usize]) -> f64 {
269 let pv: Vec<f64> = parent_values.iter().map(|&v| v as f64).collect();
270 self.density(value as f64, &pv)
271 }
272
273 fn cardinality(&self) -> usize {
274 0 }
276
277 fn parent_indices(&self) -> &[usize] {
278 &self.parent_indices
279 }
280
281 fn is_continuous(&self) -> bool {
282 true
283 }
284}
285
286#[derive(Debug, Clone)]
294pub struct MixtureCPD {
295 pub node_idx: usize,
297 pub components: Vec<TabularCPD>,
299 pub weights: Vec<f64>,
301}
302
303impl MixtureCPD {
304 pub fn new(
306 node_idx: usize,
307 components: Vec<TabularCPD>,
308 weights: Vec<f64>,
309 ) -> Result<Self, StatsError> {
310 if components.is_empty() {
311 return Err(StatsError::InvalidInput(
312 "MixtureCPD needs at least one component".to_string(),
313 ));
314 }
315 if components.len() != weights.len() {
316 return Err(StatsError::InvalidInput(
317 "components and weights must have the same length".to_string(),
318 ));
319 }
320 let wsum: f64 = weights.iter().sum();
321 if (wsum - 1.0).abs() > 1e-6 {
322 return Err(StatsError::InvalidInput(format!(
323 "weights must sum to 1.0 (got {wsum:.6})"
324 )));
325 }
326 for w in &weights {
327 if *w < 0.0 {
328 return Err(StatsError::InvalidInput(
329 "weights must be non-negative".to_string(),
330 ));
331 }
332 }
333 Ok(Self {
334 node_idx,
335 components,
336 weights,
337 })
338 }
339}
340
341impl CPD for MixtureCPD {
342 fn node(&self) -> usize {
343 self.node_idx
344 }
345
346 fn prob(&self, value: usize, parent_values: &[usize]) -> f64 {
347 self.components
348 .iter()
349 .zip(&self.weights)
350 .map(|(c, w)| w * c.prob(value, parent_values))
351 .sum()
352 }
353
354 fn cardinality(&self) -> usize {
355 self.components[0].cardinality()
356 }
357
358 fn parent_indices(&self) -> &[usize] {
359 self.components[0].parent_indices()
360 }
361}
362
/// Discrete distribution whose class probabilities are a softmax over linear
/// functions of the parent values.
///
/// Class `k` receives the logit `b[k] + sum(w[k][i] * parent[i])`.
#[derive(Clone, Debug)]
pub struct ConditionalLinear {
    /// Index of the node this distribution describes.
    pub node_idx: usize,
    /// Weight vector of each class.
    pub w: Vec<Vec<f64>>,
    /// Bias of each class.
    pub b: Vec<f64>,
    /// One value per class; length-checked by `ConditionalLinear::new` but not
    /// read by any method in this file.
    /// NOTE(review): presumably a per-class scale — confirm intended use.
    pub sigma: Vec<f64>,
    /// Number of classes, i.e. the node's cardinality.
    pub n_classes: usize,
    /// Indices of the parent nodes.
    pub parent_indices: Vec<usize>,
}
386
387impl ConditionalLinear {
388 pub fn new(
390 node_idx: usize,
391 w: Vec<Vec<f64>>,
392 b: Vec<f64>,
393 sigma: Vec<f64>,
394 n_classes: usize,
395 parent_indices: Vec<usize>,
396 ) -> Result<Self, StatsError> {
397 if w.len() != n_classes || b.len() != n_classes || sigma.len() != n_classes {
398 return Err(StatsError::InvalidInput(
399 "w, b, sigma must all have length n_classes".to_string(),
400 ));
401 }
402 Ok(Self {
403 node_idx,
404 w,
405 b,
406 sigma,
407 n_classes,
408 parent_indices,
409 })
410 }
411
412 pub fn softmax(&self, parent_values: &[f64]) -> Vec<f64> {
414 let logits: Vec<f64> = self
415 .w
416 .iter()
417 .zip(&self.b)
418 .map(|(wk, bk)| {
419 bk + wk
420 .iter()
421 .zip(parent_values)
422 .map(|(wi, xi)| wi * xi)
423 .sum::<f64>()
424 })
425 .collect();
426 let max_l = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
427 let exps: Vec<f64> = logits.iter().map(|l| (l - max_l).exp()).collect();
428 let sum: f64 = exps.iter().sum();
429 exps.iter().map(|e| e / sum).collect()
430 }
431}
432
433impl CPD for ConditionalLinear {
434 fn node(&self) -> usize {
435 self.node_idx
436 }
437
438 fn prob(&self, value: usize, parent_values: &[usize]) -> f64 {
439 if value >= self.n_classes {
440 return 0.0;
441 }
442 let pv: Vec<f64> = parent_values.iter().map(|&v| v as f64).collect();
443 let probs = self.softmax(&pv);
444 probs[value]
445 }
446
447 fn cardinality(&self) -> usize {
448 self.n_classes
449 }
450
451 fn parent_indices(&self) -> &[usize] {
452 &self.parent_indices
453 }
454}
455
/// Computes row-major strides for a list of cardinalities.
///
/// The last stride is 1 and every earlier stride is the product of all
/// cardinalities to its right, e.g. `[2, 3, 4]` yields `[12, 4, 1]`. An empty
/// input yields an empty vector.
pub(crate) fn compute_strides(card: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; card.len()];
    let mut running = 1usize;
    // Walk right to left: skip the last slot (it stays 1) and pair every
    // remaining slot with the cardinality one position to its right.
    for (slot, &c) in strides.iter_mut().rev().skip(1).zip(card.iter().rev()) {
        running *= c;
        *slot = running;
    }
    strides
}
470
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by the floating-point comparisons below.
    const EPS: f64 = 1e-9;

    /// Returns `true` when `actual` is within `EPS` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    /// Parentless binary node with probabilities 0.8 / 0.2.
    fn rain_cpd() -> TabularCPD {
        let table = vec![vec![0.8, 0.2]];
        TabularCPD::new(0, 2, vec![], vec![], table).unwrap()
    }

    /// Binary node with two binary parents (nodes 0 and 1).
    fn wetgrass_cpd() -> TabularCPD {
        // Rows follow the parent configurations (0,0), (0,1), (1,0), (1,1).
        let table = vec![
            vec![0.99, 0.01],
            vec![0.01, 0.99],
            vec![0.01, 0.99],
            vec![0.01, 0.99],
        ];
        TabularCPD::new(2, 2, vec![0, 1], vec![2, 2], table).unwrap()
    }

    #[test]
    fn test_tabular_no_parents() {
        let cpd = rain_cpd();
        assert!(close(cpd.prob(0, &[]), 0.8));
        assert!(close(cpd.prob(1, &[]), 0.2));
    }

    #[test]
    fn test_tabular_with_parents() {
        let cpd = wetgrass_cpd();
        assert!(close(cpd.prob(1, &[1, 0]), 0.99));
        assert!(close(cpd.prob(0, &[0, 0]), 0.99));
    }

    #[test]
    fn test_tabular_bad_sum() {
        // The row sums to 0.8, so construction must fail.
        let res = TabularCPD::new(0, 2, vec![], vec![], vec![vec![0.5, 0.3]]);
        assert!(res.is_err());
    }

    #[test]
    fn test_gaussian_cpd() {
        let cpd = GaussianCPD::new(0, 0.0, 1.0, vec![0.5], vec![1]).unwrap();
        // Mean is 0.0 + 0.5 * 2.0 = 1.0, so x = 1.0 sits at the peak.
        let d = cpd.density(1.0, &[2.0]);
        let expected = 1.0 / (2.0 * PI).sqrt();
        assert!(close(d, expected));
    }

    #[test]
    fn test_mixture_cpd() {
        let c1 = TabularCPD::new(0, 2, vec![], vec![], vec![vec![0.6, 0.4]]).unwrap();
        let c2 = TabularCPD::new(0, 2, vec![], vec![], vec![vec![0.4, 0.6]]).unwrap();
        let mix = MixtureCPD::new(0, vec![c1, c2], vec![0.5, 0.5]).unwrap();
        // 0.5 * 0.6 + 0.5 * 0.4 = 0.5
        assert!(close(mix.prob(0, &[]), 0.5));
    }

    #[test]
    fn test_conditional_linear() {
        let w = vec![vec![1.0], vec![-1.0]];
        let b = vec![0.0, 0.0];
        let sigma = vec![1.0, 1.0];
        let cpd = ConditionalLinear::new(0, w, b, sigma, 2, vec![1]).unwrap();
        // With the parent at 0 both logits are 0, so the classes are equally likely.
        assert!(close(cpd.prob(0, &[0]), 0.5));
    }

    #[test]
    fn test_strides() {
        assert_eq!(compute_strides(&[2, 3]), vec![3, 1]);
        assert_eq!(compute_strides(&[2, 3, 4]), vec![12, 4, 1]);
        assert_eq!(compute_strides(&[]), Vec::<usize>::new());
    }
}