uncertain_rs/
uncertain.rs

1use crate::computation::{ComputationNode, SampleContext};
2use crate::operations::Arithmetic;
3use crate::traits::Shareable;
4use std::sync::Arc;
5
6/// A type that represents uncertain data as a probability distribution
7/// using sampling-based computation with conditional semantics.
8///
9/// `Uncertain` provides a way to work with probabilistic values
10/// by representing them as sampling functions with a computation graph
11/// for lazy evaluation and proper uncertainty-aware conditionals.
12#[derive(Clone)]
13pub struct Uncertain<T> {
14    /// Unique identifier for caching purposes
15    pub(crate) id: uuid::Uuid,
16    /// The sampling function that generates values from this distribution
17    pub sample_fn: Arc<dyn Fn() -> T + Send + Sync>,
18    /// The computation graph node for lazy evaluation
19    pub(crate) node: ComputationNode<T>,
20}
21
22impl<T> Uncertain<T>
23where
24    T: Shareable,
25{
26    /// Creates an uncertain value with the given sampling function.
27    ///
28    /// # Example
29    /// ```rust
30    /// use uncertain_rs::Uncertain;
31    ///
32    /// let custom = Uncertain::new(|| {
33    ///     // Your custom sampling logic
34    ///     rand::random::<f64>() * 10.0
35    /// });
36    /// ```
37    pub fn new<F>(sampler: F) -> Self
38    where
39        F: Fn() -> T + Send + Sync + 'static,
40    {
41        let sampler = Arc::new(sampler);
42        let id = uuid::Uuid::new_v4();
43        let node = ComputationNode::Leaf {
44            id,
45            sample: sampler.clone(),
46        };
47
48        Self {
49            id,
50            sample_fn: sampler,
51            node,
52        }
53    }
54
55    /// Internal constructor with computation node for building computation graphs
56    pub(crate) fn with_node(node: ComputationNode<T>) -> Self
57    where
58        T: Arithmetic,
59    {
60        let node_clone = node.clone();
61        let sample_fn = Arc::new(move || {
62            let mut context = SampleContext::new();
63            node_clone.evaluate_conditional_with_arithmetic(&mut context)
64        });
65        let id = uuid::Uuid::new_v4();
66
67        Self {
68            id,
69            sample_fn,
70            node,
71        }
72    }
73
74    /// Get the unique identifier for this uncertain value
75    ///
76    /// This is primarily used for caching purposes.
77    #[must_use]
78    pub fn id(&self) -> uuid::Uuid {
79        self.id
80    }
81
82    /// Generate a sample from this distribution
83    ///
84    /// # Example
85    /// ```rust
86    /// use uncertain_rs::Uncertain;
87    ///
88    /// let normal = Uncertain::normal(0.0, 1.0);
89    /// let sample = normal.sample();
90    /// println!("Sample: {}", sample);
91    /// ```
92    #[must_use]
93    pub fn sample(&self) -> T {
94        (self.sample_fn)()
95    }
96
97    /// Transforms an uncertain value by applying a function to each sample.
98    ///
99    /// # Example
100    /// ```rust
101    /// use uncertain_rs::Uncertain;
102    ///
103    /// let celsius = Uncertain::normal(20.0, 2.0);
104    /// let fahrenheit = celsius.map(|c| c * 9.0/5.0 + 32.0);
105    /// ```
106    #[must_use]
107    pub fn map<U, F>(&self, transform: F) -> Uncertain<U>
108    where
109        U: Shareable,
110        F: Fn(T) -> U + Send + Sync + 'static,
111    {
112        let sample_fn = self.sample_fn.clone();
113        Uncertain::new(move || transform(sample_fn()))
114    }
115
116    /// Transforms an uncertain value by applying a function that returns another uncertain value.
117    ///
118    /// # Example
119    /// ```rust
120    /// use uncertain_rs::Uncertain;
121    ///
122    /// let base = Uncertain::normal(5.0, 1.0);
123    /// let dependent = base.flat_map(|b| Uncertain::normal(b, 0.5));
124    /// ```
125    #[must_use]
126    pub fn flat_map<U, F>(&self, transform: F) -> Uncertain<U>
127    where
128        U: Shareable,
129        F: Fn(T) -> Uncertain<U> + Send + Sync + 'static,
130    {
131        let sample_fn = self.sample_fn.clone();
132        Uncertain::new(move || transform(sample_fn()).sample())
133    }
134
135    /// Filters samples using rejection sampling.
136    ///
137    /// Only samples that satisfy the predicate are accepted.
138    /// This method will keep sampling until a valid sample is found,
139    /// so ensure the predicate has a reasonable acceptance rate.
140    ///
141    /// # Example
142    /// ```rust
143    /// use uncertain_rs::Uncertain;
144    ///
145    /// let normal = Uncertain::normal(0.0, 1.0);
146    /// let positive_only = normal.filter(|&x| x > 0.0);
147    /// ```
148    #[must_use]
149    pub fn filter<F>(&self, predicate: F) -> Uncertain<T>
150    where
151        F: Fn(&T) -> bool + Send + Sync + 'static,
152    {
153        let sample_fn = self.sample_fn.clone();
154        Uncertain::new(move || {
155            loop {
156                let value = sample_fn();
157                if predicate(&value) {
158                    return value;
159                }
160            }
161        })
162    }
163
164    /// Generate an iterator of samples
165    ///
166    /// # Example
167    /// ```rust
168    /// use uncertain_rs::Uncertain;
169    ///
170    /// let normal = Uncertain::normal(0.0, 1.0);
171    /// let first_10: Vec<f64> = normal.samples().take(10).collect();
172    /// ```
173    #[must_use = "iterators are lazy and do nothing unless consumed"]
174    pub fn samples(&self) -> impl Iterator<Item = T> + '_ {
175        std::iter::repeat_with(|| self.sample())
176    }
177
178    /// Take a specific number of samples
179    ///
180    /// # Example
181    /// ```rust
182    /// use uncertain_rs::Uncertain;
183    ///
184    /// let uniform = Uncertain::uniform(0.0, 1.0);
185    /// let samples = uniform.take_samples(1000);
186    /// ```
187    #[must_use]
188    pub fn take_samples(&self, count: usize) -> Vec<T> {
189        self.samples().take(count).collect()
190    }
191}
192
193impl Uncertain<f64> {
194    /// Take samples with caching for better performance on repeated requests
195    ///
196    /// This is especially useful for expensive computations that might be called
197    /// multiple times with the same sample count.
198    ///
199    /// # Example
200    /// ```rust
201    /// use uncertain_rs::Uncertain;
202    ///
203    /// let gamma = Uncertain::gamma(2.0, 1.0);
204    /// let samples = gamma.take_samples_cached(1000); // Cached for reuse
205    /// ```
206    #[must_use]
207    pub fn take_samples_cached(&self, count: usize) -> Vec<f64> {
208        crate::cache::dist_cache()
209            .get_or_compute_samples(self.id, count, || self.samples().take(count).collect())
210    }
211}
212
213impl<T> Uncertain<T>
214where
215    T: Shareable + PartialOrd,
216{
217    /// Compare this uncertain value with another, returning an uncertain boolean
218    #[must_use]
219    pub fn less_than(&self, other: &Self) -> Uncertain<bool> {
220        let self_fn = self.sample_fn.clone();
221        let other_fn = other.sample_fn.clone();
222
223        Uncertain::new(move || {
224            let a = self_fn();
225            let b = other_fn();
226            a < b
227        })
228    }
229
230    /// Compare this uncertain value with another, returning an uncertain boolean
231    #[must_use]
232    pub fn greater_than(&self, other: &Self) -> Uncertain<bool> {
233        let self_fn = self.sample_fn.clone();
234        let other_fn = other.sample_fn.clone();
235
236        Uncertain::new(move || {
237            let a = self_fn();
238            let b = other_fn();
239            a > b
240        })
241    }
242}
243
244impl<T> Uncertain<T>
245where
246    T: Shareable + PartialOrd + PartialEq + Copy,
247{
248    /// Returns uncertain boolean evidence that this value is greater than threshold
249    ///
250    /// # Example
251    /// ```rust
252    /// use uncertain_rs::Uncertain;
253    ///
254    /// let speed = Uncertain::normal(55.2, 5.0);
255    /// let speeding_evidence = speed.gt(60.0);
256    ///
257    /// if speeding_evidence.probability_exceeds(0.95) {
258    ///     println!("Issue speeding ticket");
259    /// }
260    /// ```
261    #[must_use]
262    pub fn gt(&self, threshold: T) -> Uncertain<bool> {
263        let sample_fn = self.sample_fn.clone();
264        Uncertain::new(move || sample_fn() > threshold)
265    }
266
267    /// Returns uncertain boolean evidence that this value is less than threshold
268    #[must_use]
269    pub fn lt(&self, threshold: T) -> Uncertain<bool> {
270        let sample_fn = self.sample_fn.clone();
271        Uncertain::new(move || sample_fn() < threshold)
272    }
273
274    /// Returns uncertain boolean evidence that this value is greater than or equal to threshold
275    #[must_use]
276    pub fn ge(&self, threshold: T) -> Uncertain<bool> {
277        let sample_fn = self.sample_fn.clone();
278        Uncertain::new(move || sample_fn() >= threshold)
279    }
280
281    /// Returns uncertain boolean evidence that this value is less than or equal to threshold
282    #[must_use]
283    pub fn le(&self, threshold: T) -> Uncertain<bool> {
284        let sample_fn = self.sample_fn.clone();
285        Uncertain::new(move || sample_fn() <= threshold)
286    }
287
288    /// Returns uncertain boolean evidence that this value equals threshold
289    ///
290    /// Note: For floating point types, exact equality is rarely meaningful.
291    /// Consider using range-based comparisons instead.
292    #[must_use]
293    pub fn eq_value(&self, threshold: T) -> Uncertain<bool> {
294        let sample_fn = self.sample_fn.clone();
295        Uncertain::new(move || sample_fn() == threshold)
296    }
297
298    /// Returns uncertain boolean evidence that this value does not equal threshold
299    #[must_use]
300    pub fn ne_value(&self, threshold: T) -> Uncertain<bool> {
301        let sample_fn = self.sample_fn.clone();
302        Uncertain::new(move || sample_fn() != threshold)
303    }
304}
305
306impl<T> std::cmp::PartialEq for Uncertain<T>
307where
308    T: Shareable + PartialEq,
309{
310    fn eq(&self, other: &Self) -> bool {
311        // This is a fallback for direct equality testing
312        let sample_a = self.sample();
313        let sample_b = other.sample();
314        sample_a == sample_b
315    }
316}
317
318impl<T> std::cmp::PartialOrd for Uncertain<T>
319where
320    T: Shareable + PartialOrd,
321{
322    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
323        let sample_a = self.sample();
324        let sample_b = other.sample();
325        sample_a.partial_cmp(&sample_b)
326    }
327
328    fn lt(&self, other: &Self) -> bool {
329        let sample_a = self.sample();
330        let sample_b = other.sample();
331        sample_a < sample_b
332    }
333
334    fn gt(&self, other: &Self) -> bool {
335        let sample_a = self.sample();
336        let sample_b = other.sample();
337        sample_a > sample_b
338    }
339}
340
341impl<T> std::fmt::Debug for Uncertain<T>
342where
343    T: Shareable + std::fmt::Debug,
344{
345    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346        f.debug_struct("Uncertain")
347            .field("sample", &self.sample())
348            .finish()
349    }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_uncertain() {
        let uncertain = Uncertain::new(|| 42.0_f64);
        assert!((uncertain.sample() - 42.0_f64).abs() < f64::EPSILON);
    }

    #[test]
    fn test_sample() {
        let uncertain = Uncertain::new(|| std::f64::consts::PI);
        assert!((uncertain.sample() - std::f64::consts::PI).abs() < f64::EPSILON);
        // A deterministic sampler must keep producing the same value.
        assert!((uncertain.sample() - std::f64::consts::PI).abs() < f64::EPSILON);
    }

    #[test]
    fn test_map() {
        let uncertain = Uncertain::new(|| 5.0_f64);
        let mapped = uncertain.map(|x| x * 2.0);
        assert!((mapped.sample() - 10.0_f64).abs() < f64::EPSILON);
    }

    #[test]
    // The cast is intentional: 5.0 converts to i32 exactly.
    #[allow(clippy::cast_possible_truncation)]
    fn test_map_type_conversion() {
        let uncertain = Uncertain::new(|| 5.0_f64);
        let mapped = uncertain.map(|x| x as i32);
        assert_eq!(mapped.sample(), 5);
    }

    #[test]
    fn test_flat_map() {
        let base = Uncertain::new(|| 3.0_f64);
        let dependent = base.flat_map(|x| Uncertain::new(move || x + 1.0));
        assert!((dependent.sample() - 4.0_f64).abs() < f64::EPSILON);
    }

    #[test]
    fn test_flat_map_chain() {
        let base = Uncertain::new(|| 2.0_f64);
        let chained = base
            .flat_map(|x| Uncertain::new(move || x * 2.0))
            .flat_map(|x| Uncertain::new(move || x + 1.0));
        assert!((chained.sample() - 5.0_f64).abs() < f64::EPSILON);
    }

    #[test]
    fn test_filter() {
        let uncertain = Uncertain::new(|| 10.0);
        let filtered = uncertain.filter(|&x| x > 5.0);
        assert!(filtered.sample() > 5.0);
    }

    #[test]
    fn test_filter_rejection_sampling() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicI32, Ordering};
        let counter = Arc::new(AtomicI32::new(0));
        let counter_clone = counter.clone();
        let uncertain = Uncertain::new(move || {
            let count = counter_clone.fetch_add(1, Ordering::SeqCst);
            if count < 3 { 1.0 } else { 10.0 }
        });
        let filtered = uncertain.filter(|&x| x > 5.0);
        assert!(filtered.sample() > 5.0);
        // Three rejected draws (1.0) followed by one accepted draw (10.0):
        // proves that rejection sampling actually happened.
        assert_eq!(counter.load(Ordering::SeqCst), 4);
    }

    #[test]
    fn test_samples_iterator() {
        let uncertain = Uncertain::new(|| 42.0);
        let samples: Vec<f64> = uncertain.samples().take(5).collect();
        assert_eq!(samples, vec![42.0, 42.0, 42.0, 42.0, 42.0]);
    }

    #[test]
    fn test_take_samples() {
        let uncertain = Uncertain::new(|| 7.0);
        let samples = uncertain.take_samples(3);
        assert_eq!(samples, vec![7.0, 7.0, 7.0]);
    }

    #[test]
    fn test_take_samples_empty() {
        let uncertain = Uncertain::new(|| 1.0);
        let samples = uncertain.take_samples(0);
        assert!(samples.is_empty());
    }

    #[test]
    fn test_less_than() {
        let smaller = Uncertain::new(|| 1.0);
        let larger = Uncertain::new(|| 2.0);
        let comparison = smaller.less_than(&larger);
        assert!(comparison.sample());
    }

    #[test]
    fn test_less_than_false() {
        let larger = Uncertain::new(|| 2.0);
        let smaller = Uncertain::new(|| 1.0);
        let comparison = larger.less_than(&smaller);
        assert!(!comparison.sample());
    }

    #[test]
    fn test_greater_than() {
        let larger = Uncertain::new(|| 2.0);
        let smaller = Uncertain::new(|| 1.0);
        let comparison = larger.greater_than(&smaller);
        assert!(comparison.sample());
    }

    #[test]
    fn test_greater_than_false() {
        let smaller = Uncertain::new(|| 1.0);
        let larger = Uncertain::new(|| 2.0);
        let comparison = smaller.greater_than(&larger);
        assert!(!comparison.sample());
    }

    #[test]
    fn test_partial_eq() {
        let a = Uncertain::new(|| 5.0);
        let b = Uncertain::new(|| 5.0);
        let c = Uncertain::new(|| 10.0);

        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn test_partial_ord() {
        let smaller = Uncertain::new(|| 1.0);
        let larger = Uncertain::new(|| 2.0);

        assert!(smaller < larger);
        assert!(larger > smaller);
        assert_eq!(smaller.partial_cmp(&larger), Some(std::cmp::Ordering::Less));
    }

    #[test]
    fn test_partial_ord_equal() {
        let a = Uncertain::new(|| 5.0);
        let b = Uncertain::new(|| 5.0);

        // Identical constant samplers must compare as Equal, not just "some ordering".
        assert_eq!(a.partial_cmp(&b), Some(std::cmp::Ordering::Equal));
        assert_eq!(b.partial_cmp(&a), Some(std::cmp::Ordering::Equal));
    }

    #[test]
    fn test_debug_formatting() {
        let uncertain = Uncertain::new(|| 42);
        let debug_str = format!("{uncertain:?}");
        assert!(debug_str.contains("Uncertain"));
        assert!(debug_str.contains("42"));
    }

    #[test]
    fn test_clone() {
        let original = Uncertain::new(|| 123.0_f64);
        let cloned = original.clone();

        assert_eq!(original.id(), cloned.id());
        assert!((original.sample() - cloned.sample()).abs() < f64::EPSILON);
        assert!((original.sample() - 123.0_f64).abs() < f64::EPSILON);
        assert!((cloned.sample() - 123.0_f64).abs() < f64::EPSILON);
    }

    #[test]
    fn test_with_random_sampler() {
        use rand::random;
        let uncertain = Uncertain::new(random::<f64>);

        // `rand::random::<f64>()` draws from the half-open interval [0, 1);
        // every sample must stay inside it.
        let sample1 = uncertain.sample();
        let sample2 = uncertain.sample();
        assert!((0.0..1.0).contains(&sample1));
        assert!((0.0..1.0).contains(&sample2));
    }

    #[test]
    fn test_map_preserves_uncertainty() {
        use rand::random;
        let base = Uncertain::new(random::<f64>);
        let transformed = base.map(|x| x * 100.0);

        // Scaling a [0, 1) sample by 100 must land within [0, 100].
        let sample = transformed.sample();
        assert!((0.0..=100.0).contains(&sample));
    }

    #[test]
    fn test_gt_method_api() {
        let speed = Uncertain::new(|| 65.0);
        let speeding_evidence = speed.gt(60.0);
        assert!(speeding_evidence.sample()); // 65 > 60
    }

    #[test]
    fn test_lt_method_api() {
        let temperature = Uncertain::new(|| -5.0);
        let freezing_evidence = temperature.lt(0.0);
        assert!(freezing_evidence.sample()); // -5 < 0
    }

    #[test]
    fn test_ge_method_api() {
        let value = Uncertain::new(|| 10.0);
        let evidence = value.ge(10.0);
        assert!(evidence.sample()); // 10 >= 10
    }

    #[test]
    fn test_le_method_api() {
        let value = Uncertain::new(|| 5.0);
        let evidence = value.le(10.0);
        assert!(evidence.sample()); // 5 <= 10
    }

    #[test]
    fn test_eq_value_method_api() {
        let value = Uncertain::new(|| 42);
        let evidence = value.eq_value(42);
        assert!(evidence.sample()); // 42 == 42
    }

    #[test]
    fn test_ne_value_method_api() {
        let value = Uncertain::new(|| 42);
        let evidence = value.ne_value(0);
        assert!(evidence.sample()); // 42 != 0
    }

    #[test]
    fn test_readme_example_api() {
        // Exercise the exact API shown in the README.
        let speed = Uncertain::normal(55.2, 5.0);
        let speeding_evidence = speed.gt(60.0);

        // Must compile and run; the outcome is random, so it is not asserted.
        let _result = speeding_evidence.probability_exceeds(0.95);

        // A point value well above the threshold must always exceed it.
        let high_speed = Uncertain::point(70.0);
        let high_speed_evidence = high_speed.gt(60.0);
        assert!(high_speed_evidence.probability_exceeds(0.95));
    }
}