uncertain_rs/
uncertain.rs1use crate::computation::{ComputationNode, SampleContext};
2use crate::operations::Arithmetic;
3use crate::traits::Shareable;
4use std::sync::Arc;
5
6#[derive(Clone)]
13pub struct Uncertain<T> {
14 pub(crate) id: uuid::Uuid,
16 pub sample_fn: Arc<dyn Fn() -> T + Send + Sync>,
18 pub(crate) node: ComputationNode<T>,
20}
21
22impl<T> Uncertain<T>
23where
24 T: Shareable,
25{
26 pub fn new<F>(sampler: F) -> Self
38 where
39 F: Fn() -> T + Send + Sync + 'static,
40 {
41 let sampler = Arc::new(sampler);
42 let id = uuid::Uuid::new_v4();
43 let node = ComputationNode::Leaf {
44 id,
45 sample: sampler.clone(),
46 };
47
48 Self {
49 id,
50 sample_fn: sampler,
51 node,
52 }
53 }
54
55 pub(crate) fn with_node(node: ComputationNode<T>) -> Self
57 where
58 T: Arithmetic,
59 {
60 let node_clone = node.clone();
61 let sample_fn = Arc::new(move || {
62 let mut context = SampleContext::new();
63 node_clone.evaluate_conditional_with_arithmetic(&mut context)
64 });
65 let id = uuid::Uuid::new_v4();
66
67 Self {
68 id,
69 sample_fn,
70 node,
71 }
72 }
73
74 #[must_use]
78 pub fn id(&self) -> uuid::Uuid {
79 self.id
80 }
81
82 #[must_use]
93 pub fn sample(&self) -> T {
94 (self.sample_fn)()
95 }
96
97 #[must_use]
107 pub fn map<U, F>(&self, transform: F) -> Uncertain<U>
108 where
109 U: Shareable,
110 F: Fn(T) -> U + Send + Sync + 'static,
111 {
112 let sample_fn = self.sample_fn.clone();
113 Uncertain::new(move || transform(sample_fn()))
114 }
115
116 #[must_use]
126 pub fn flat_map<U, F>(&self, transform: F) -> Uncertain<U>
127 where
128 U: Shareable,
129 F: Fn(T) -> Uncertain<U> + Send + Sync + 'static,
130 {
131 let sample_fn = self.sample_fn.clone();
132 Uncertain::new(move || transform(sample_fn()).sample())
133 }
134
135 #[must_use]
149 pub fn filter<F>(&self, predicate: F) -> Uncertain<T>
150 where
151 F: Fn(&T) -> bool + Send + Sync + 'static,
152 {
153 let sample_fn = self.sample_fn.clone();
154 Uncertain::new(move || {
155 loop {
156 let value = sample_fn();
157 if predicate(&value) {
158 return value;
159 }
160 }
161 })
162 }
163
164 #[must_use = "iterators are lazy and do nothing unless consumed"]
174 pub fn samples(&self) -> impl Iterator<Item = T> + '_ {
175 std::iter::repeat_with(|| self.sample())
176 }
177
178 #[must_use]
188 pub fn take_samples(&self, count: usize) -> Vec<T> {
189 self.samples().take(count).collect()
190 }
191}
192
193impl Uncertain<f64> {
194 #[must_use]
207 pub fn take_samples_cached(&self, count: usize) -> Vec<f64> {
208 crate::cache::dist_cache()
209 .get_or_compute_samples(self.id, count, || self.samples().take(count).collect())
210 }
211}
212
213impl<T> Uncertain<T>
214where
215 T: Shareable + PartialOrd,
216{
217 #[must_use]
219 pub fn less_than(&self, other: &Self) -> Uncertain<bool> {
220 let self_fn = self.sample_fn.clone();
221 let other_fn = other.sample_fn.clone();
222
223 Uncertain::new(move || {
224 let a = self_fn();
225 let b = other_fn();
226 a < b
227 })
228 }
229
230 #[must_use]
232 pub fn greater_than(&self, other: &Self) -> Uncertain<bool> {
233 let self_fn = self.sample_fn.clone();
234 let other_fn = other.sample_fn.clone();
235
236 Uncertain::new(move || {
237 let a = self_fn();
238 let b = other_fn();
239 a > b
240 })
241 }
242}
243
244impl<T> Uncertain<T>
245where
246 T: Shareable + PartialOrd + PartialEq + Copy,
247{
248 #[must_use]
262 pub fn gt(&self, threshold: T) -> Uncertain<bool> {
263 let sample_fn = self.sample_fn.clone();
264 Uncertain::new(move || sample_fn() > threshold)
265 }
266
267 #[must_use]
269 pub fn lt(&self, threshold: T) -> Uncertain<bool> {
270 let sample_fn = self.sample_fn.clone();
271 Uncertain::new(move || sample_fn() < threshold)
272 }
273
274 #[must_use]
276 pub fn ge(&self, threshold: T) -> Uncertain<bool> {
277 let sample_fn = self.sample_fn.clone();
278 Uncertain::new(move || sample_fn() >= threshold)
279 }
280
281 #[must_use]
283 pub fn le(&self, threshold: T) -> Uncertain<bool> {
284 let sample_fn = self.sample_fn.clone();
285 Uncertain::new(move || sample_fn() <= threshold)
286 }
287
288 #[must_use]
293 pub fn eq_value(&self, threshold: T) -> Uncertain<bool> {
294 let sample_fn = self.sample_fn.clone();
295 Uncertain::new(move || sample_fn() == threshold)
296 }
297
298 #[must_use]
300 pub fn ne_value(&self, threshold: T) -> Uncertain<bool> {
301 let sample_fn = self.sample_fn.clone();
302 Uncertain::new(move || sample_fn() != threshold)
303 }
304}
305
306impl<T> std::cmp::PartialEq for Uncertain<T>
307where
308 T: Shareable + PartialEq,
309{
310 fn eq(&self, other: &Self) -> bool {
311 let sample_a = self.sample();
313 let sample_b = other.sample();
314 sample_a == sample_b
315 }
316}
317
318impl<T> std::cmp::PartialOrd for Uncertain<T>
319where
320 T: Shareable + PartialOrd,
321{
322 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
323 let sample_a = self.sample();
324 let sample_b = other.sample();
325 sample_a.partial_cmp(&sample_b)
326 }
327
328 fn lt(&self, other: &Self) -> bool {
329 let sample_a = self.sample();
330 let sample_b = other.sample();
331 sample_a < sample_b
332 }
333
334 fn gt(&self, other: &Self) -> bool {
335 let sample_a = self.sample();
336 let sample_b = other.sample();
337 sample_a > sample_b
338 }
339}
340
341impl<T> std::fmt::Debug for Uncertain<T>
342where
343 T: Shareable + std::fmt::Debug,
344{
345 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346 f.debug_struct("Uncertain")
347 .field("sample", &self.sample())
348 .finish()
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_uncertain() {
        let uncertain = Uncertain::new(|| 42.0_f64);
        assert!((uncertain.sample() - 42.0_f64).abs() < f64::EPSILON);
    }

    #[test]
    fn test_new_assigns_distinct_ids() {
        let a = Uncertain::new(|| 1.0_f64);
        let b = Uncertain::new(|| 1.0_f64);
        assert_ne!(a.id(), b.id());
    }

    #[test]
    fn test_sample() {
        let uncertain = Uncertain::new(|| std::f64::consts::PI);
        assert!((uncertain.sample() - std::f64::consts::PI).abs() < f64::EPSILON);
        assert!((uncertain.sample() - std::f64::consts::PI).abs() < f64::EPSILON);
    }

    #[test]
    fn test_map() {
        let uncertain = Uncertain::new(|| 5.0_f64);
        let mapped = uncertain.map(|x| x * 2.0);
        assert!((mapped.sample() - 10.0_f64).abs() < f64::EPSILON);
    }

    #[test]
    // The sampled value is the exact integer 5.0, so the float-to-int cast cannot truncate.
    #[allow(clippy::cast_possible_truncation)]
    fn test_map_type_conversion() {
        let uncertain = Uncertain::new(|| 5.0_f64);
        let mapped = uncertain.map(|x| x as i32);
        assert_eq!(mapped.sample(), 5);
    }

    #[test]
    fn test_flat_map() {
        let base = Uncertain::new(|| 3.0_f64);
        let dependent = base.flat_map(|x| Uncertain::new(move || x + 1.0));
        assert!((dependent.sample() - 4.0_f64).abs() < f64::EPSILON);
    }

    #[test]
    fn test_flat_map_chain() {
        let base = Uncertain::new(|| 2.0_f64);
        let chained = base
            .flat_map(|x| Uncertain::new(move || x * 2.0))
            .flat_map(|x| Uncertain::new(move || x + 1.0));
        assert!((chained.sample() - 5.0_f64).abs() < f64::EPSILON);
    }

    #[test]
    fn test_filter() {
        let uncertain = Uncertain::new(|| 10.0);
        let filtered = uncertain.filter(|&x| x > 5.0);
        assert!(filtered.sample() > 5.0);
    }

    #[test]
    fn test_filter_rejection_sampling() {
        use std::sync::atomic::{AtomicI32, Ordering};
        use std::sync::Arc;

        let counter = Arc::new(AtomicI32::new(0));
        let counter_clone = counter.clone();
        let uncertain = Uncertain::new(move || {
            let count = counter_clone.fetch_add(1, Ordering::SeqCst);
            if count < 3 { 1.0 } else { 10.0 }
        });
        let filtered = uncertain.filter(|&x| x > 5.0);

        let value = filtered.sample();
        assert!((value - 10.0_f64).abs() < f64::EPSILON);
        // Draws 0, 1 and 2 yield 1.0 and must be rejected; draw 3 yields 10.0 and is accepted.
        assert_eq!(counter.load(Ordering::SeqCst), 4);
    }

    #[test]
    fn test_samples_iterator() {
        let uncertain = Uncertain::new(|| 42.0);
        let samples: Vec<f64> = uncertain.samples().take(5).collect();
        assert_eq!(samples, vec![42.0, 42.0, 42.0, 42.0, 42.0]);
    }

    #[test]
    fn test_take_samples() {
        let uncertain = Uncertain::new(|| 7.0);
        let samples = uncertain.take_samples(3);
        assert_eq!(samples, vec![7.0, 7.0, 7.0]);
    }

    #[test]
    fn test_take_samples_empty() {
        let uncertain = Uncertain::new(|| 1.0);
        let samples = uncertain.take_samples(0);
        assert!(samples.is_empty());
    }

    #[test]
    fn test_less_than() {
        let smaller = Uncertain::new(|| 1.0);
        let larger = Uncertain::new(|| 2.0);
        let comparison = smaller.less_than(&larger);
        assert!(comparison.sample());
    }

    #[test]
    fn test_less_than_false() {
        let larger = Uncertain::new(|| 2.0);
        let smaller = Uncertain::new(|| 1.0);
        let comparison = larger.less_than(&smaller);
        assert!(!comparison.sample());
    }

    #[test]
    fn test_greater_than() {
        let larger = Uncertain::new(|| 2.0);
        let smaller = Uncertain::new(|| 1.0);
        let comparison = larger.greater_than(&smaller);
        assert!(comparison.sample());
    }

    #[test]
    fn test_greater_than_false() {
        let smaller = Uncertain::new(|| 1.0);
        let larger = Uncertain::new(|| 2.0);
        let comparison = smaller.greater_than(&larger);
        assert!(!comparison.sample());
    }

    #[test]
    fn test_partial_eq() {
        let a = Uncertain::new(|| 5.0);
        let b = Uncertain::new(|| 5.0);
        let c = Uncertain::new(|| 10.0);

        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn test_partial_ord() {
        let smaller = Uncertain::new(|| 1.0);
        let larger = Uncertain::new(|| 2.0);

        assert!(smaller < larger);
        assert!(larger > smaller);
        assert_eq!(
            smaller.partial_cmp(&larger),
            Some(std::cmp::Ordering::Less)
        );
    }

    #[test]
    fn test_partial_ord_equal() {
        let a = Uncertain::new(|| 5.0);
        let b = Uncertain::new(|| 5.0);

        assert_eq!(a.partial_cmp(&b), Some(std::cmp::Ordering::Equal));
        assert_eq!(b.partial_cmp(&a), Some(std::cmp::Ordering::Equal));
    }

    #[test]
    fn test_debug_formatting() {
        let uncertain = Uncertain::new(|| 42);
        let debug_str = format!("{uncertain:?}");
        assert!(debug_str.contains("Uncertain"));
        assert!(debug_str.contains("42"));
    }

    #[test]
    fn test_clone() {
        let original = Uncertain::new(|| 123.0_f64);
        let cloned = original.clone();

        assert!((original.sample() - cloned.sample()).abs() < f64::EPSILON);
        assert!((original.sample() - 123.0_f64).abs() < f64::EPSILON);
        assert!((cloned.sample() - 123.0_f64).abs() < f64::EPSILON);
        // A clone is the same uncertain variable, not a new one.
        assert_eq!(original.id(), cloned.id());
    }

    #[test]
    fn test_with_random_sampler() {
        use rand::random;
        let uncertain = Uncertain::new(random::<f64>);

        let sample1 = uncertain.sample();
        let sample2 = uncertain.sample();
        assert!((0.0..=1.0).contains(&sample1));
        assert!((0.0..=1.0).contains(&sample2));
    }

    #[test]
    fn test_map_preserves_uncertainty() {
        use rand::random;
        let base = Uncertain::new(random::<f64>);
        let transformed = base.map(|x| x * 100.0);

        let sample = transformed.sample();
        assert!((0.0..=100.0).contains(&sample));
    }

    #[test]
    fn test_gt_method_api() {
        let speed = Uncertain::new(|| 65.0);
        let speeding_evidence = speed.gt(60.0);
        assert!(speeding_evidence.sample());
    }

    #[test]
    fn test_lt_method_api() {
        let temperature = Uncertain::new(|| -5.0);
        let freezing_evidence = temperature.lt(0.0);
        assert!(freezing_evidence.sample());
    }

    #[test]
    fn test_ge_method_api() {
        let value = Uncertain::new(|| 10.0);
        let evidence = value.ge(10.0);
        assert!(evidence.sample());
    }

    #[test]
    fn test_le_method_api() {
        let value = Uncertain::new(|| 5.0);
        let evidence = value.le(10.0);
        assert!(evidence.sample());
    }

    #[test]
    fn test_eq_value_method_api() {
        let value = Uncertain::new(|| 42);
        let evidence = value.eq_value(42);
        assert!(evidence.sample());
    }

    #[test]
    fn test_ne_value_method_api() {
        let value = Uncertain::new(|| 42);
        let evidence = value.ne_value(0);
        assert!(evidence.sample());
    }

    #[test]
    fn test_readme_example_api() {
        let speed = Uncertain::normal(55.2, 5.0);
        let speeding_evidence = speed.gt(60.0);

        // Only exercises the API; the outcome is a statistical decision and is not asserted.
        let _result = speeding_evidence.probability_exceeds(0.95);

        let high_speed = Uncertain::point(70.0);
        let high_speed_evidence = high_speed.gt(60.0);
        assert!(high_speed_evidence.probability_exceeds(0.95));
    }
}