1use crate::{Domain, ErrorKind, Result};
3use ordered_float::NotNan;
4use rand::distributions::Distribution;
5use rand::Rng;
6use std::num::NonZeroU64;
7
/// A newtype around a vector of per-component values of type `T`.
///
/// When `T` is a [`Domain`], this type is itself a [`Domain`] whose points are
/// vectors holding one point per component.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct VecDomain<T>(
    /// The wrapped components, in order.
    pub Vec<T>,
);
11
12impl<T: Domain> Domain for VecDomain<T> {
13 type Point = Vec<T::Point>;
14}
15
16impl<T> Distribution<Vec<T::Point>> for VecDomain<T>
17where
18 T: Domain + Distribution<<T as Domain>::Point>,
19{
20 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<T::Point> {
21 self.0.iter().map(|t| t.sample(rng)).collect()
22 }
23}
24
/// A domain described only by a non-zero number of categories.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CategoricalDomain {
    /// Number of categories; never zero by construction of the type.
    cardinality: NonZeroU64,
}
30impl CategoricalDomain {
31 pub fn new(cardinality: u64) -> Result<Self> {
37 let cardinality = track_assert_some!(NonZeroU64::new(cardinality), ErrorKind::InvalidInput);
38 Ok(Self { cardinality })
39 }
40
41 pub const fn cardinality(&self) -> NonZeroU64 {
43 self.cardinality
44 }
45}
46impl Domain for CategoricalDomain {
47 type Point = u64;
48}
49impl From<NonZeroU64> for CategoricalDomain {
50 fn from(cardinality: NonZeroU64) -> Self {
51 Self { cardinality }
52 }
53}
54impl Distribution<u64> for CategoricalDomain {
55 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
56 rng.gen_range(0..self.cardinality.get())
57 }
58}
59
/// A domain described only by a non-zero number of discrete values.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DiscreteDomain {
    /// Number of values in the domain; never zero by construction of the type.
    size: NonZeroU64,
}
65impl DiscreteDomain {
66 pub fn new(size: u64) -> Result<Self> {
72 let size = track_assert_some!(NonZeroU64::new(size), ErrorKind::InvalidInput);
73 Ok(Self { size })
74 }
75
76 pub const fn size(&self) -> NonZeroU64 {
78 self.size
79 }
80}
81impl Domain for DiscreteDomain {
82 type Point = u64;
83}
84impl From<NonZeroU64> for DiscreteDomain {
85 fn from(size: NonZeroU64) -> Self {
86 Self { size }
87 }
88}
89impl Distribution<u64> for DiscreteDomain {
90 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
91 rng.gen_range(0..self.size.get())
92 }
93}
94
95#[derive(Debug, Clone, PartialEq, Eq, Hash)]
97pub struct ContinuousDomain {
98 low: NotNan<f64>,
99 high: NotNan<f64>,
100}
101impl ContinuousDomain {
102 pub fn new(low: f64, high: f64) -> Result<Self> {
114 track_assert!(low.is_finite(), ErrorKind::InvalidInput; low, high);
115 track_assert!(high.is_finite(), ErrorKind::InvalidInput; low, high);
116 track_assert!(low < high, ErrorKind::InvalidInput; low, high);
117 track_assert!((high - low).is_finite(), ErrorKind::InvalidInput; low, high);
118
119 Ok(unsafe {
120 Self {
121 low: NotNan::unchecked_new(low),
122 high: NotNan::unchecked_new(high),
123 }
124 })
125 }
126
127 pub fn low(&self) -> f64 {
129 self.low.into_inner()
130 }
131
132 pub fn high(&self) -> f64 {
134 self.high.into_inner()
135 }
136
137 pub fn size(&self) -> f64 {
139 self.high() - self.low()
140 }
141}
142impl Domain for ContinuousDomain {
143 type Point = f64;
144}
145impl Distribution<f64> for ContinuousDomain {
146 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
147 rng.gen_range(self.low()..self.high())
148 }
149}