test_sampler/lib.rs

//! Tools for unit testing of sampling algorithms.
//!
//! The crate has been developed in particular to help with the development of
//! Monte Carlo particle transport codes, where a large number of different
//! sampling procedures is required to stochastically simulate the physical
//! interactions of radiation with matter.
//!
//! Models are usually described with differential cross-sections
//! $\frac{d\sigma}{dE'd\Omega}(E)$, which provide the shape of the probability
//! density function. The normalisation, however, is often difficult to obtain
//! without complex integration.
//!
//! For that reason this package is composed of two parts:
//! - [FunctionSampler], which allows one to (inefficiently) draw samples from
//!   a non-normalised pdf shape function
//! - A suite of statistical tests in [crate::stat_tests], which allow one to verify
//!   that samples from a tested distribution match those generated with [FunctionSampler]
//!
//! Thus, to verify sampling one needs to:
//! - Verify the shape function with deterministic unit tests
//! - Compare the sampling procedure against reference samples from [FunctionSampler]
//!   using statistical tests
//!
//! Note that, as a result of statistical uncertainty and the variable *power* of
//! statistical tests for different defects and sample populations, the sampling
//! unit tests can never provide the same level of certainty as deterministic
//! ones. The appropriate number of samples and type(s) of test(s) will also
//! depend on the particular application.
//!
//! # Example
//!
//! Let us verify simple inversion sampling of $f(x) = x$ on $\[0;1\]$,
//! for which we can generate samples with $\hat{f} = \sqrt{r}$, where $r$ is
//! a uniformly distributed random number.
//!
//! ```
//! use test_sampler::FunctionSampler;
//! use rand::prelude::*;
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//!
//! // Seed rngs
//! let mut rng1 = StdRng::seed_from_u64(87674);
//! let mut rng2 = StdRng::seed_from_u64(87674);
//!
//! // Draw reference samples from the FunctionSampler
//! let support = 0.0..1.0;
//! let num_bins = 30;
//!
//! let reference_dist = FunctionSampler::new(|x| x, support, num_bins)?;
//! let s_ref: Vec<f64> = rng1.sample_iter(&reference_dist).take(100).collect();
//!
//! // Samples to test
//! let s: Vec<f64> = (0..100).map(|_| rng2.gen()).map(|r: f64| r.sqrt()).collect();
//!
//! // Perform tests
//! // Vectors of samples will be moved inside the procedures
//! // It is necessary since the samples must be sorted (and mutated)
//! let ks_res = test_sampler::stat_tests::ks2_test(s_ref.clone(), s.clone())?;
//! let kup_res = test_sampler::stat_tests::kuiper2_test(s_ref.clone(), s.clone())?;
//! let ad_res = test_sampler::stat_tests::ad2_test(s_ref, s)?;
//!
//! // Check test results
//! assert!(ks_res.p_value() > 0.05);
//! assert!(kup_res.p_value() > 0.05);
//! assert!(ad_res.p_value() > 0.05);
//!
//! # Ok(()) }
//! ```
//!
use argmin::core::CostFunction;
use is_sorted::IsSorted;
use std::ops::Range;
use thiserror::Error;

pub mod stat_tests;

/// Error raised when the setup of a sampling distribution has failed
///
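/// A minimal sketch of how a setup error surfaces (the unsorted grid below is
/// arbitrary):
///
/// ```
/// use test_sampler::{HistogramDistribution, SetupError};
///
/// let err = HistogramDistribution::new(vec![1.0, 0.5, 3.0], vec![0.1, 0.1, 0.1]).unwrap_err();
/// assert!(matches!(err, SetupError::UnsortedGrid));
/// ```
///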
#[derive(Error, Debug)]
pub enum SetupError {
    /// Grid for tabulated values is not sorted
    #[error("Values in a grid are not sorted")]
    UnsortedGrid,
    /// Negative entries were found in the probability density function (pdf)
    #[error("Negative values present in probability density function")]
    NegativePdf,
    /// Lengths of the vectors that form a table are not the same
    #[error("Lengths of arrays to form a table are different")]
    LengthMismatch,
    /// Wraps errors from a failed optimisation by [argmin]
    #[error("Optimisation with argmin has failed")]
    OptimisationError(#[from] argmin::core::Error),
    /// Number of bins to construct the tabulated data (`0`) is lower than required (`1`)
    #[error("Insufficient number of bins: {0}, must have at least {1}")]
    InsufficientBins(usize, usize),
    /// Was given an empty range where a non-empty interval is required
    #[error("Got empty range where non-empty is required")]
    EmptyRange,
}

/// Distribution described by a non-normalised histogram
///
/// The distribution is given as a table of `x` and `pdf` values which follows
/// histogram interpolation. For `x ∈ [xᵢ; xᵢ₊₁)` the probability density is
/// `pdf(x) = pdfᵢ`. The cumulative distribution function is therefore piece-wise
/// linear, which makes sampling from the table quite easy.
///
/// To support a better approximation of different pdfs, the grid is in general
/// not equally spaced. Hence a binary search is needed to find the correct bin.
///
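/// # Example
///
/// A minimal usage sketch (the grid and pdf values below are arbitrary):
/// ```
/// use test_sampler::HistogramDistribution;
/// use rand::prelude::*;
///
/// # fn main() -> Result<(), test_sampler::SetupError> {
/// // Three grid points define two bins; the last pdf entry is ignored
/// let hist = HistogramDistribution::new(vec![0.0, 1.0, 3.0], vec![2.0, 1.0, 1.0])?;
/// let mut rng = StdRng::seed_from_u64(1);
/// let samples: Vec<f64> = rng.sample_iter(&hist).take(5).collect();
/// assert!(samples.iter().all(|s| (0.0..3.0).contains(s)));
/// # Ok(()) }
/// ```
///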
#[derive(Debug, Clone)]
pub struct HistogramDistribution {
    x: Vec<f64>,
    pdf: Vec<f64>,
    cdf: Vec<f64>,
}

///
/// Search a sorted grid of values and find the lower bound
///
/// Returns `i` such that `grid[i] <= val < grid[i + 1]`
///
/// Local function required to search the [HistogramDistribution]
///
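/// For example, for `grid = [0.0, 1.0, 2.0]`, `search_sorted(&grid, 1.5)` yields
/// `Some(1)`, while `search_sorted(&grid, 2.0)` lies outside the half-open
/// support and yields `None`.
///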
fn search_sorted<T>(grid: &[T], val: T) -> Option<usize>
where
    T: PartialOrd,
{
    let first = grid.first().unwrap();
    let last = grid.last().unwrap();

    if !(first..last).contains(&&val) {
        return None;
    }

    match grid.binary_search_by(|k| k.partial_cmp(&val).unwrap()) {
        Ok(j) => Some(j),
        Err(j) => Some(j - 1),
    }
}

/// Calculate a non-normalised cdf from a histogram probability density
///
/// # Panics
/// If `x` and `pdf` do not match in length
///
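/// For example, `x = [-1.0, 0.5, 3.0]` with `pdf = [0.1, 0.2, 0.3]` yields
/// `[0.0, 0.15, 0.65]`; the last `pdf` entry is never used.
///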
fn histogram_cdf(x: &[f64], pdf: &[f64]) -> Vec<f64> {
    if x.len() != pdf.len() {
        panic!("Length mismatch")
    }

    let mut cdf = Vec::with_capacity(x.len());
    cdf.push(0.0);

    let dx_iter = x.windows(2).map(|w| w[1] - w[0]);

    for (dx, p) in std::iter::zip(dx_iter, pdf.iter()) {
        // We know the cdf is never empty
        let top = cdf.last().unwrap();
        cdf.push(top + *p * dx);
    }
    cdf
}

impl HistogramDistribution {
    /// Create a new instance of the histogram distribution
    ///
    /// The cumulative distribution function will be calculated.
    ///
    /// The condition `x.len() == pdf.len() > 1` must be met. Note that the last
    /// value in the `pdf` vector is ignored, since `x.len()` grid points define
    /// only `x.len() - 1` bins.
    ///
    pub fn new(x: Vec<f64>, pdf: Vec<f64>) -> Result<Self, SetupError> {
        if !IsSorted::is_sorted(&mut x.iter()) {
            return Err(SetupError::UnsortedGrid);
        } else if pdf.iter().any(|v| *v < 0.0) {
            return Err(SetupError::NegativePdf);
        } else if x.len() != pdf.len() {
            return Err(SetupError::LengthMismatch);
        } else if x.len() <= 1 {
            return Err(SetupError::InsufficientBins(x.len(), 2));
        }

        let cdf = histogram_cdf(&x, &pdf);
        Ok(Self { x, pdf, cdf })
    }

    /// Sample a value from the histogram and return the non-normalised probability as well
    ///
    /// # Result
    /// Tuple `(s, p)` where `s` is the sample and `p` is the (non-normalised)
    /// probability density in the bin
    ///
    /// Returning the bin probability alongside the sample allows a rejection
    /// sampling scheme to be implemented without repeating the binary search of the grid.
    ///
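    /// For example, a rejection-sampling loop against a shape function `f` could
    /// look like the sketch below (the shape function and histogram values are
    /// arbitrary):
    /// ```
    /// # use test_sampler::HistogramDistribution;
    /// use rand::prelude::*;
    ///
    /// # fn main() -> Result<(), test_sampler::SetupError> {
    /// let f = |x: f64| x; // shape function, bounded above by the histogram
    /// let hist = HistogramDistribution::new(vec![0.0, 0.5, 1.0], vec![0.5, 1.0, 1.0])?;
    /// let mut rng = StdRng::seed_from_u64(7);
    ///
    /// let sample = loop {
    ///     let (s, p_top) = hist.sample_with_value(&mut rng);
    ///     // Accept the sample with probability f(s) / p_top
    ///     if f(s) / p_top > rng.gen() {
    ///         break s;
    ///     }
    /// };
    /// assert!((0.0..1.0).contains(&sample));
    /// # Ok(()) }
    /// ```
    ///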
    pub fn sample_with_value<RNG>(&self, rng: &mut RNG) -> (f64, f64)
    where
        RNG: rand::Rng + ?Sized,
    {
        // We know cdf is not empty
        let val = rng.gen_range(0.0..*self.cdf.last().unwrap());
        let idx = search_sorted(&self.cdf, val).unwrap();

        let x0 = self.x[idx];
        let p0 = self.pdf[idx];
        let c0 = self.cdf[idx];
        ((val - c0) / p0 + x0, p0)
    }
}

/// Draws samples from the histogram distribution
///
impl rand::distributions::Distribution<f64> for HistogramDistribution {
    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        self.sample_with_value(rng).0
    }
}

///
/// Wrap a function to maximise
///
/// We need a separate struct to use the [argmin] library, because some traits must
/// be implemented for the wrapped function to act as an optimisation problem for
/// other `argmin` components.
///
/// Since, by convention, objectives are minimised, we also need to take the
/// negative of `function` as the cost.
///
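/// For example, with `function = |x| x * x` the cost evaluated at `x = 2.0` is `-4.0`.
///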
struct FlipSign<T: Fn(f64) -> f64> {
    pub function: T,
}

impl<T> CostFunction for FlipSign<T>
where
    T: Fn(f64) -> f64,
{
    type Param = f64;
    type Output = f64;

    fn cost(&self, param: &Self::Param) -> Result<Self::Output, argmin::core::Error> {
        Ok(-(self.function)(*param))
    }
}

///
/// Creates a linearly spaced grid of size `n` between `start` and `end`
///
/// ```
/// # use test_sampler::linspace;
/// assert_eq!(vec![1.0, 2.0, 3.0], linspace(1.0, 3.0, 3));
/// assert_eq!(vec![3.0, 2.0, 1.0], linspace(3.0, 1.0, 3));
/// assert_eq!(vec![3.0, 3.0, 3.0], linspace(3.0, 3.0, 3));
/// ```
///
/// # Panics
/// If the number of points `n` is 0 or 1
///
pub fn linspace(start: f64, end: f64, n: usize) -> Vec<f64> {
    if n < 2 {
        panic!("Grid cannot have {n} values. At least 2 are required.")
    }

    let delta = (end - start) / (n - 1) as f64;
    (0..n).map(|i| start + delta * i as f64).collect()
}

/// Distribution described by a non-negative shape function on an interval
///
/// Allows sampling from a generic distribution described by some shape function $f(x)$.
/// The function does not need to be normalised, i.e. $\int f(x) dx \ne 1$ in general.
///
/// It is intended to be used as a reference distribution for the verification of
/// more efficient sampling algorithms.
///
/// # Sampling procedure
///
/// To generate the samples a relatively expensive setup step is required. The user
/// provides an interval $[x_0, x_1]$ which is the support of the shape function $f(x)$.
/// The support is then subdivided into a number of bins, and in each bin a local
/// maximum is found numerically. This yields a histogram approximation of the
/// probability distribution that 'tops' the actual distribution. Hence, we may
/// draw samples from the histogram and use a rejection scheme to obtain the
/// distribution described by $f(x)$.
///
/// Since the numerical maximisation is subject to some tolerance, a safety
/// factor of 5% is applied to the local maxima to ensure that the supremum criterion
/// required for rejection sampling is met.
///
/// Note that for very sharp functions (large $\frac{df}{dx}$) the safety factor may be
/// insufficient. Also, a general check that $f(x) \ge 0 ~\forall~ x \in [x_0; x_1]$
/// is not feasible. Hence sampling may **panic** if either condition is violated.
///
/// # Usage
///
/// The distribution is integrated with the [rand] package and can be used to construct
/// a sampling iterator as follows:
/// ```
/// # use test_sampler::FunctionSampler;
/// use rand::{self, Rng};
///
/// # fn main() -> Result<(), test_sampler::SetupError> {
/// let dist = FunctionSampler::new(|x| -x*x + x, 0.0..1.0, 30)?;
/// let samples = rand::thread_rng().sample_iter(&dist).take(10);
/// # Ok(())
/// # }
/// ```
///
#[derive(Debug, Clone)]
pub struct FunctionSampler<T: Fn(f64) -> f64> {
    function: T,
    hist: HistogramDistribution,
}

impl<T> FunctionSampler<T>
where
    T: Fn(f64) -> f64,
{
    /// Safety factor applied to the bin maxima to protect against errors due
    /// to the tolerance of the numerical optimisation
    pub const SAFETY_FACTOR: f64 = 1.05;

    /// Helper function to find the maximum in a given bin
    ///
    fn maximise(function: &T, start: f64, end: f64) -> Result<f64, SetupError> {
        let problem = FlipSign { function };
        let solver = argmin::solver::brent::BrentOpt::new(start, end);
        let res = argmin::core::Executor::new(problem, solver).run()?;

        Ok(-res.state().cost)
    }

    /// New sampler from components
    ///
    /// # Arguments
    /// - `function` - A function `Fn(f64) -> f64` that is non-negative on `range`
    /// - `range` - Support of the probability distribution with the shape of `function`
    /// - `bins` - Number of bins used to construct the topping histogram (at least 1)
    ///
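    /// A minimal sketch of the error handling at setup (the shape function is arbitrary):
    /// ```
    /// # use test_sampler::FunctionSampler;
    /// // An empty support range is rejected during construction
    /// assert!(FunctionSampler::new(|x| -x * x + x, 1.0..0.0, 30).is_err());
    /// ```
    ///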
    pub fn new(function: T, range: Range<f64>, bins: usize) -> Result<Self, SetupError> {
        if bins == 0 {
            return Err(SetupError::InsufficientBins(bins, 1));
        } else if range.is_empty() {
            return Err(SetupError::EmptyRange);
        }

        // Create subdivision grid
        let grid = linspace(range.start, range.end, bins + 1);

        // Get maxima in each window
        // The maxima vector must match the grid in length, so the last value is duplicated
        let mut maxima = grid
            .windows(2)
            .map(|x| Self::maximise(&function, x[0], x[1]).map(|x| Self::SAFETY_FACTOR * x))
            .collect::<Result<Vec<f64>, SetupError>>()?;
        maxima.push(*maxima.last().unwrap());

        // Construct topping distribution
        let hist = HistogramDistribution::new(grid, maxima)?;

        // Profit
        Ok(Self { function, hist })
    }
}

/// Draws samples from the FunctionSampler
///
/// # Panics
/// Sampling may panic if the shape function has negative values or the upper bound
/// is not fulfilled. This may happen if the gradient of the shape function
/// is steep and the safety factor on the tolerance of the optimisation was
/// insufficient.
///
impl<T> rand::distributions::Distribution<f64> for FunctionSampler<T>
where
    T: Fn(f64) -> f64,
{
    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        // Draw samples from the histogram until one is accepted
        loop {
            let (sample, p_top) = self.hist.sample_with_value(rng);
            let p_val = (self.function)(sample);
            if p_top < p_val {
                panic!("Upper bound {p_top} is lower than {p_val} at {sample}");
            } else if p_val < 0.0 {
                panic!("Negative value {p_val} at {sample}")
            }
            if p_val / p_top > rng.gen() {
                return sample;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx;
    use rand::{Rng, SeedableRng};

    #[test]
    fn test_histogram_calculation() {
        let x = vec![-1.0, 0.5, 3.0];
        let pdf = vec![0.1, 0.2, 0.3];
        let cdf = vec![0.0, 0.15, 0.65];

        let cdf_calc = histogram_cdf(&x, &pdf);

        for (v_ref, v_calc) in std::iter::zip(cdf, cdf_calc) {
            approx::assert_relative_eq!(v_ref, v_calc);
        }
    }

    #[test]
    #[should_panic]
    fn test_invalid_histogram_calculation() {
        let _ = histogram_cdf(&[0.0, 1.0, 3.0], &[0.1, 0.1]);
    }

    #[test]
    fn test_histogram_distribution() {
        let x = vec![-1.0, 0.5, 3.0];
        let pdf = vec![0.1, 0.2, 0.3];
        let _ = HistogramDistribution::new(x, pdf).unwrap();
    }

    #[test]
    fn test_search_of_sorted_grid() {
        let grid = [-0.5, 1.0, 2.0, 6.0];

        // Out of range
        assert_eq!(None, search_sorted(&grid, -0.6));
        assert_eq!(None, search_sorted(&grid, 6.0));
        assert_eq!(None, search_sorted(&grid, 9.0));

        // In range
        assert_eq!(Some(0), search_sorted(&grid, -0.5));
        assert_eq!(Some(0), search_sorted(&grid, 0.0));

        assert_eq!(Some(1), search_sorted(&grid, 1.0));
        assert_eq!(Some(1), search_sorted(&grid, 1.5));

        assert_eq!(Some(2), search_sorted(&grid, 5.0));
    }

    #[test]
    fn test_histogram_pdf_sampling() {
        let support = [-1.0, 0.5, 3.0, 4.0];
        let pdf = [0.1, 0.2, 0.35, 0.35];
        let cdf = [0.0, 0.15, 0.65, 1.0];
        let n_samples = 10000;
        let mut rng = rand::rngs::StdRng::seed_from_u64(87674);

        let dist = HistogramDistribution::new(support.into(), pdf.into()).unwrap();

        let samples = (0..n_samples)
            .map(|_| rng.sample(&dist))
            .collect::<Vec<_>>();

        let ks_res = stat_tests::ks1_test(
            |x| {
                let idx = search_sorted(&support, *x).unwrap();
                let x0 = support[idx];
                let x1 = support[idx + 1];
                let c0 = cdf[idx];
                let c1 = cdf[idx + 1];
                (x - x0) / (x1 - x0) * (c1 - c0) + c0
            },
            samples,
        )
        .unwrap();

        // Print the test results in case of a failure
        println!("{:?}", ks_res);
        assert!(ks_res.p_value() > 0.01)
    }

    #[test]
    fn test_histogram_distribution_errors() {
        let x = vec![-1.0, 0.5, 3.0];
        let pdf = vec![0.1, 0.2, 0.3];

        assert!(
            HistogramDistribution::new(vec![1.0, 0.5, 3.0], pdf.clone()).is_err(),
            "Failed to detect non-sorted grid"
        );
        assert!(
            HistogramDistribution::new(x.clone(), vec![0.1, -0.2, 0.3]).is_err(),
            "Failed to detect -ve pdf"
        );
        assert!(
            HistogramDistribution::new(x.clone(), vec![0.1, 0.3]).is_err(),
            "Failed to detect length mismatch"
        );
        assert!(
            HistogramDistribution::new(vec![0.1], vec![0.1]).is_err(),
            "Failed to detect too short vectors"
        );
    }

    #[test]
    fn test_function_sampler_sampling() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(87674);
        let dist = FunctionSampler::new(|x| -x * x + x, 0.0..1.0, 30).unwrap();
        let n_samples = 10000;

        let samples = (0..n_samples)
            .map(|_| rng.sample(&dist))
            .collect::<Vec<_>>();

        let ks_res = stat_tests::ks1_test(|x| 3.0 * x * x - 2.0 * x * x * x, samples).unwrap();
        // Print the test results in case of a failure
        println!("{:?}", ks_res);
        assert!(ks_res.p_value() > 0.01)
    }

    #[test]
    fn test_function_sampler_setup_errors() {
        assert!(
            FunctionSampler::new(|x| -x * x + x, 1.0..0.0, 30).is_err(),
            "Failed to detect empty range"
        );
        assert!(
            FunctionSampler::new(|x| -x * x + x, 0.0..1.0, 0).is_err(),
            "Failed to detect insufficient number of bins"
        );
        assert!(
            FunctionSampler::new(|x| -x * x + x - 0.2, 0.0..1.0, 30).is_err(),
            "Failed to detect negative maxima in the bins"
        );
    }

    #[test]
    #[should_panic]
    fn test_function_sampler_negative_pdf() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(87674);

        // We select only a single bin so that the negative pdf is hidden from the histogram
        let dist = FunctionSampler::new(|x| -x * x + x - 0.1, 0.0..1.0, 1).unwrap();

        // Will panic on sampling
        let _samples = (0..100).map(|_| rng.sample(&dist)).collect::<Vec<_>>();
    }
}