spaces/partition.rs

1use crate::{Interval, prelude::*};
2use std::{cmp, fmt, ops::Range};
3
/// Finite, uniformly partitioned interval.
///
/// The interval between `lb` and `ub` is split into `N` cells of equal width.
/// The number of cells is carried in the type (`N`), not stored as a field; the
/// cells themselves are identified by their index in `0..N`.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct Equipartition<const N: usize> {
    /// Lower bound of the partitioned interval.
    lb: f64,
    /// Upper bound of the partitioned interval.
    ub: f64,
}
11
12impl<const N: usize> Equipartition<N> {
13    pub fn new(lb: f64, ub: f64) -> Equipartition<N> {
14        if N == 0 {
15            panic!("A partition must have a number partitions of 1 or greater.")
16        }
17
18        Equipartition { lb, ub, }
19    }
20
21    pub fn from_interval<I: Into<Interval>>(d: I) -> Equipartition<N> {
22        let interval = d.into();
23
24        Equipartition {
25            lb: interval.lb.expect("Must be a bounded interval."),
26            ub: interval.ub.expect("Must be a bounded interval."),
27        }
28    }
29
30    #[inline]
31    pub fn n_partitions(&self) -> usize { N }
32
33    #[inline]
34    pub fn partition_width(&self) -> f64 { (self.ub - self.lb) / N as f64 }
35
36    pub fn centres(&self) -> [f64; N] {
37        let w = self.partition_width();
38        let hw = w / 2.0;
39        let mut output = [f64::default(); N];
40
41        for i in 0..N {
42            output[i] = self.lb + w * ((i + 1) as f64) - hw;
43        }
44
45        output
46    }
47
48    pub fn edges(&self) -> [f64; N] {
49        let w = self.partition_width();
50        let mut output = [f64::default(); N];
51
52        for i in 0..N {
53            output[i] = self.lb + w * (i as f64);
54        }
55
56        output
57    }
58
59    pub fn to_partition(&self, val: f64) -> usize {
60        let clipped = clip!(self.lb, val, self.ub);
61
62        let diff = clipped - self.lb;
63        let range = self.ub - self.lb;
64
65        let i = ((N as f64) * diff / range).floor() as usize;
66
67        if i >= N { N - 1 } else { i }
68    }
69}
70
71impl<const N: usize> Space for Equipartition<N> {
72    const DIM: usize = 1;
73
74    type Value = usize;
75
76    fn card(&self) -> Card { Card::Finite(N) }
77
78    fn contains(&self, val: &usize) -> bool { *val < N }
79}
80
81impl<const N: usize> OrderedSpace for Equipartition<N> {
82    fn min(&self) -> Option<usize> { Some(0) }
83
84    fn max(&self) -> Option<usize> { Some(N - 1) }
85}
86
87impl<const N: usize> FiniteSpace for Equipartition<N> {
88    fn to_ordinal(&self) -> Range<Self::Value> { 0..N }
89}
90
91impl<const N: usize, const M: usize> cmp::PartialEq<Equipartition<M>> for Equipartition<N> {
92    fn eq(&self, other: &Equipartition<M>) -> bool {
93        N.eq(&M) && self.lb.eq(&other.lb) && self.ub.eq(&other.ub)
94    }
95}
96
97impl<const N: usize> fmt::Display for Equipartition<N> {
98    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
99        match N {
100            n if n == 1 => write!(f, "{{{} = x0, x1 = {}}}", self.lb, self.ub),
101            n if n == 2 => write!(f, "{{{} = x0, x1, x2 = {}}}", self.lb, self.ub),
102            n => write!(f, "{{{} = x0, x1, ..., x{} = {}}}", self.lb, n, self.ub),
103        }
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "serialize")]
    extern crate serde_test;
    #[cfg(feature = "serialize")]
    use self::serde_test::{assert_tokens, Token};

    #[test]
    fn test_from_interval() {
        assert_eq!(
            Equipartition::<5>::new(0.0, 5.0),
            Equipartition::<5>::from_interval(Interval::bounded(0.0, 5.0))
        );
    }

    #[test]
    #[should_panic]
    fn test_new_rejects_zero_partitions() {
        let _ = Equipartition::<0>::new(0.0, 1.0);
    }

    #[test]
    fn test_density() {
        assert_eq!(Equipartition::<5>::new(0.0, 5.0).n_partitions(), 5);
        assert_eq!(Equipartition::<10>::new(0.0, 5.0).n_partitions(), 10);
        assert_eq!(Equipartition::<100>::new(-5.0, 5.0).n_partitions(), 100);
    }

    #[test]
    fn test_partition_width() {
        assert_eq!(Equipartition::<5>::new(0.0, 5.0).partition_width(), 1.0);
        assert_eq!(Equipartition::<10>::new(0.0, 5.0).partition_width(), 0.5);
        assert_eq!(Equipartition::<10>::new(-5.0, 5.0).partition_width(), 1.0);
    }

    #[test]
    fn test_centres() {
        assert_eq!(
            Equipartition::<5>::new(0.0, 5.0).centres(),
            [0.5, 1.5, 2.5, 3.5, 4.5]
        );

        assert_eq!(
            Equipartition::<5>::new(-5.0, 5.0).centres(),
            [-4.0, -2.0, 0.0, 2.0, 4.0]
        );
    }

    #[test]
    fn test_to_partition() {
        let d = Equipartition::<6>::new(0.0, 5.0);

        assert_eq!(d.to_partition(-1.0), 0);
        assert_eq!(d.to_partition(0.0), 0);
        assert_eq!(d.to_partition(1.0), 1);
        assert_eq!(d.to_partition(2.0), 2);
        assert_eq!(d.to_partition(3.0), 3);
        assert_eq!(d.to_partition(4.0), 4);
        assert_eq!(d.to_partition(5.0), 5);
        assert_eq!(d.to_partition(6.0), 5);
    }

    #[test]
    fn test_dim() {
        assert_eq!(Equipartition::<1>::DIM, 1);
        assert_eq!(Equipartition::<5>::DIM, 1);
        assert_eq!(Equipartition::<10>::DIM, 1);
    }

    #[test]
    fn test_card() {
        fn check<const N: usize>(lb: f64, ub: f64) {
            let d = Equipartition::<N>::new(lb, ub);

            assert_eq!(d.card(), Card::Finite(N));
        }

        check::<5>(0.0, 5.0);
        check::<5>(-5.0, 0.0);
        check::<10>(-5.0, 5.0);
    }

    #[test]
    fn test_to_ordinal() {
        fn check<const N: usize>(lb: f64, ub: f64) {
            let d = Equipartition::<N>::new(lb, ub);

            assert_eq!(d.to_ordinal(), 0..N);
        }

        check::<5>(0.0, 5.0);
        check::<5>(-5.0, 0.0);
        check::<10>(-5.0, 5.0);
    }

    #[cfg(feature = "serialize")]
    #[test]
    fn test_serialisation() {
        // Only `lb` and `ub` are stored; the partition count is the const
        // parameter `N`, so it does not appear in the serialized form.
        fn check<const N: usize>(lb: f64, ub: f64) {
            let d = Equipartition::<N>::new(lb, ub);

            assert_tokens(
                &d,
                &[
                    Token::Struct {
                        name: "Equipartition",
                        len: 2,
                    },
                    Token::Str("lb"),
                    Token::F64(lb),
                    Token::Str("ub"),
                    Token::F64(ub),
                    Token::StructEnd,
                ],
            );
        }

        check::<5>(0.0, 5.0);
        check::<10>(-5.0, 5.0);
        check::<5>(-5.0, 0.0);
    }
}