yamakan/
domains.rs

1//! Parameter search domains.
2use crate::{Domain, ErrorKind, Result};
3use ordered_float::NotNan;
4use rand::distributions::Distribution;
5use rand::Rng;
6use std::num::NonZeroU64;
7
8/// Vector domain.
9#[derive(Debug, Clone, PartialEq, Eq, Hash)]
10pub struct VecDomain<T>(pub Vec<T>);
11
12impl<T: Domain> Domain for VecDomain<T> {
13    type Point = Vec<T::Point>;
14}
15
16impl<T> Distribution<Vec<T::Point>> for VecDomain<T>
17where
18    T: Domain + Distribution<<T as Domain>::Point>,
19{
20    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<T::Point> {
21        self.0.iter().map(|t| t.sample(rng)).collect()
22    }
23}
24
25/// Categorical domain.
26#[derive(Debug, Clone, PartialEq, Eq, Hash)]
27pub struct CategoricalDomain {
28    cardinality: NonZeroU64,
29}
30impl CategoricalDomain {
31    /// Makes a new `CategoricalDomain` instance.
32    ///
33    /// # Errors
34    ///
35    /// If `cardinality` is `0`, this function returns an `ErrorKind::InvalidInput` error.
36    pub fn new(cardinality: u64) -> Result<Self> {
37        let cardinality = track_assert_some!(NonZeroU64::new(cardinality), ErrorKind::InvalidInput);
38        Ok(Self { cardinality })
39    }
40
41    /// Returns the cardinality of this domain.
42    pub const fn cardinality(&self) -> NonZeroU64 {
43        self.cardinality
44    }
45}
46impl Domain for CategoricalDomain {
47    type Point = u64;
48}
49impl From<NonZeroU64> for CategoricalDomain {
50    fn from(cardinality: NonZeroU64) -> Self {
51        Self { cardinality }
52    }
53}
54impl Distribution<u64> for CategoricalDomain {
55    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
56        rng.gen_range(0..self.cardinality.get())
57    }
58}
59
60/// Discrete numerical domain.
61#[derive(Debug, Clone, PartialEq, Eq, Hash)]
62pub struct DiscreteDomain {
63    size: NonZeroU64,
64}
65impl DiscreteDomain {
66    /// Makes a new `DiscreteDomain` instance.
67    ///
68    /// # Errors
69    ///
70    /// If `size` is `0`, this function returns an `ErrorKind::InvalidInput` error.
71    pub fn new(size: u64) -> Result<Self> {
72        let size = track_assert_some!(NonZeroU64::new(size), ErrorKind::InvalidInput);
73        Ok(Self { size })
74    }
75
76    /// Returns the size of this domain.
77    pub const fn size(&self) -> NonZeroU64 {
78        self.size
79    }
80}
81impl Domain for DiscreteDomain {
82    type Point = u64;
83}
84impl From<NonZeroU64> for DiscreteDomain {
85    fn from(size: NonZeroU64) -> Self {
86        Self { size }
87    }
88}
89impl Distribution<u64> for DiscreteDomain {
90    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
91        rng.gen_range(0..self.size.get())
92    }
93}
94
95/// Continuous numerical domain.
96#[derive(Debug, Clone, PartialEq, Eq, Hash)]
97pub struct ContinuousDomain {
98    low: NotNan<f64>,
99    high: NotNan<f64>,
100}
101impl ContinuousDomain {
102    /// Makes a new `ContinuousDomain` instance.
103    ///
104    /// The returned instance represents a half-closed interval, i.e., `[low..high)`.
105    ///
106    /// # Errors
107    ///
108    /// If one of the following conditions is satisfied, this function returns an `ErrorKind::InvalidInput` error:
109    ///
110    /// - `low` or `high` is not a finite number
111    /// - `low >= high`
112    /// - `high - low` is not a finite number
113    pub fn new(low: f64, high: f64) -> Result<Self> {
114        track_assert!(low.is_finite(), ErrorKind::InvalidInput; low, high);
115        track_assert!(high.is_finite(), ErrorKind::InvalidInput; low, high);
116        track_assert!(low < high, ErrorKind::InvalidInput; low, high);
117        track_assert!((high - low).is_finite(), ErrorKind::InvalidInput; low, high);
118
119        Ok(unsafe {
120            Self {
121                low: NotNan::unchecked_new(low),
122                high: NotNan::unchecked_new(high),
123            }
124        })
125    }
126
127    /// Returns the lower bound of this domain.
128    pub fn low(&self) -> f64 {
129        self.low.into_inner()
130    }
131
132    /// Returns the upper bound of this domain.
133    pub fn high(&self) -> f64 {
134        self.high.into_inner()
135    }
136
137    /// Returns the size of this domain.
138    pub fn size(&self) -> f64 {
139        self.high() - self.low()
140    }
141}
142impl Domain for ContinuousDomain {
143    type Point = f64;
144}
145impl Distribution<f64> for ContinuousDomain {
146    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
147        rng.gen_range(self.low()..self.high())
148    }
149}