1use crate::{Interval, prelude::*};
2use std::{cmp, fmt, ops::Range};
3
/// A uniform partitioning of the range between `lb` and `ub` into `N` cells.
///
/// The number of cells is carried by the const generic parameter, so only the
/// two bounds are stored at runtime.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct Equipartition<const N: usize> {
    /// Lower bound of the partitioned range.
    lb: f64,

    /// Upper bound of the partitioned range.
    ub: f64,
}
11
12impl<const N: usize> Equipartition<N> {
13 pub fn new(lb: f64, ub: f64) -> Equipartition<N> {
14 if N == 0 {
15 panic!("A partition must have a number partitions of 1 or greater.")
16 }
17
18 Equipartition { lb, ub, }
19 }
20
21 pub fn from_interval<I: Into<Interval>>(d: I) -> Equipartition<N> {
22 let interval = d.into();
23
24 Equipartition {
25 lb: interval.lb.expect("Must be a bounded interval."),
26 ub: interval.ub.expect("Must be a bounded interval."),
27 }
28 }
29
30 #[inline]
31 pub fn n_partitions(&self) -> usize { N }
32
33 #[inline]
34 pub fn partition_width(&self) -> f64 { (self.ub - self.lb) / N as f64 }
35
36 pub fn centres(&self) -> [f64; N] {
37 let w = self.partition_width();
38 let hw = w / 2.0;
39 let mut output = [f64::default(); N];
40
41 for i in 0..N {
42 output[i] = self.lb + w * ((i + 1) as f64) - hw;
43 }
44
45 output
46 }
47
48 pub fn edges(&self) -> [f64; N] {
49 let w = self.partition_width();
50 let mut output = [f64::default(); N];
51
52 for i in 0..N {
53 output[i] = self.lb + w * (i as f64);
54 }
55
56 output
57 }
58
59 pub fn to_partition(&self, val: f64) -> usize {
60 let clipped = clip!(self.lb, val, self.ub);
61
62 let diff = clipped - self.lb;
63 let range = self.ub - self.lb;
64
65 let i = ((N as f64) * diff / range).floor() as usize;
66
67 if i >= N { N - 1 } else { i }
68 }
69}
70
71impl<const N: usize> Space for Equipartition<N> {
72 const DIM: usize = 1;
73
74 type Value = usize;
75
76 fn card(&self) -> Card { Card::Finite(N) }
77
78 fn contains(&self, val: &usize) -> bool { *val < N }
79}
80
81impl<const N: usize> OrderedSpace for Equipartition<N> {
82 fn min(&self) -> Option<usize> { Some(0) }
83
84 fn max(&self) -> Option<usize> { Some(N - 1) }
85}
86
87impl<const N: usize> FiniteSpace for Equipartition<N> {
88 fn to_ordinal(&self) -> Range<Self::Value> { 0..N }
89}
90
91impl<const N: usize, const M: usize> cmp::PartialEq<Equipartition<M>> for Equipartition<N> {
92 fn eq(&self, other: &Equipartition<M>) -> bool {
93 N.eq(&M) && self.lb.eq(&other.lb) && self.ub.eq(&other.ub)
94 }
95}
96
97impl<const N: usize> fmt::Display for Equipartition<N> {
98 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
99 match N {
100 n if n == 1 => write!(f, "{{{} = x0, x1 = {}}}", self.lb, self.ub),
101 n if n == 2 => write!(f, "{{{} = x0, x1, x2 = {}}}", self.lb, self.ub),
102 n => write!(f, "{{{} = x0, x1, ..., x{} = {}}}", self.lb, n, self.ub),
103 }
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "serialize")]
    extern crate serde_test;
    #[cfg(feature = "serialize")]
    use self::serde_test::{assert_tokens, Token};

    #[test]
    fn test_from_interval() {
        assert_eq!(
            Equipartition::<5>::new(0.0, 5.0),
            Equipartition::<5>::from_interval(Interval::bounded(0.0, 5.0))
        );
    }

    #[test]
    fn test_density() {
        assert_eq!(Equipartition::<5>::new(0.0, 5.0).n_partitions(), 5);
        assert_eq!(Equipartition::<10>::new(0.0, 5.0).n_partitions(), 10);
        assert_eq!(Equipartition::<100>::new(-5.0, 5.0).n_partitions(), 100);
    }

    #[test]
    fn test_partition_width() {
        assert_eq!(Equipartition::<5>::new(0.0, 5.0).partition_width(), 1.0);
        assert_eq!(Equipartition::<10>::new(0.0, 5.0).partition_width(), 0.5);
        assert_eq!(Equipartition::<10>::new(-5.0, 5.0).partition_width(), 1.0);
    }

    #[test]
    fn test_centres() {
        // The partition count is spelled out rather than inferred from the
        // length of the expected array.
        assert_eq!(
            Equipartition::<5>::new(0.0, 5.0).centres(),
            [0.5, 1.5, 2.5, 3.5, 4.5]
        );

        assert_eq!(
            Equipartition::<5>::new(-5.0, 5.0).centres(),
            [-4.0, -2.0, 0.0, 2.0, 4.0]
        );
    }

    #[test]
    fn test_edges() {
        // `edges` yields the lower edge of each cell; the upper bound itself
        // is not part of the output.
        assert_eq!(
            Equipartition::<5>::new(0.0, 5.0).edges(),
            [0.0, 1.0, 2.0, 3.0, 4.0]
        );

        assert_eq!(
            Equipartition::<5>::new(-5.0, 5.0).edges(),
            [-5.0, -3.0, -1.0, 1.0, 3.0]
        );
    }

    #[test]
    fn test_to_partition() {
        let d = Equipartition::<6>::new(0.0, 5.0);

        assert_eq!(d.to_partition(-1.0), 0);
        assert_eq!(d.to_partition(0.0), 0);
        assert_eq!(d.to_partition(1.0), 1);
        assert_eq!(d.to_partition(2.0), 2);
        assert_eq!(d.to_partition(3.0), 3);
        assert_eq!(d.to_partition(4.0), 4);
        assert_eq!(d.to_partition(5.0), 5);
        assert_eq!(d.to_partition(6.0), 5);
    }

    #[test]
    fn test_dim() {
        assert_eq!(Equipartition::<1>::DIM, 1);
        assert_eq!(Equipartition::<5>::DIM, 1);
        assert_eq!(Equipartition::<10>::DIM, 1);
    }

    #[test]
    fn test_card() {
        fn check<const N: usize>(lb: f64, ub: f64) {
            let d = Equipartition::<N>::new(lb, ub);

            assert_eq!(d.card(), Card::Finite(N));
        }

        check::<5>(0.0, 5.0);
        check::<5>(-5.0, 0.0);
        check::<10>(-5.0, 5.0);
    }

    #[test]
    fn test_to_ordinal() {
        fn check<const N: usize>(lb: f64, ub: f64) {
            let d = Equipartition::<N>::new(lb, ub);

            assert_eq!(d.to_ordinal(), 0..N);
        }

        check::<5>(0.0, 5.0);
        check::<5>(-5.0, 0.0);
        check::<10>(-5.0, 5.0);
    }

    #[test]
    fn test_display() {
        assert_eq!(
            Equipartition::<1>::new(0.0, 5.0).to_string(),
            "{0 = x0, x1 = 5}"
        );
        assert_eq!(
            Equipartition::<2>::new(0.0, 5.0).to_string(),
            "{0 = x0, x1, x2 = 5}"
        );
        assert_eq!(
            Equipartition::<5>::new(0.0, 5.0).to_string(),
            "{0 = x0, x1, ..., x5 = 5}"
        );
    }

    #[cfg(feature = "serialize")]
    #[test]
    fn test_serialisation() {
        // The partition count is a const generic and is not stored, so the
        // serialised form carries only the two bounds.
        fn check<const N: usize>(lb: f64, ub: f64) {
            let d = Equipartition::<N>::new(lb, ub);

            assert_tokens(
                &d,
                &[
                    Token::Struct {
                        name: "Equipartition",
                        len: 2,
                    },
                    Token::Str("lb"),
                    Token::F64(lb),
                    Token::Str("ub"),
                    Token::F64(ub),
                    Token::StructEnd,
                ],
            );
        }

        check::<5>(0.0, 5.0);
        check::<10>(-5.0, 5.0);
        check::<5>(-5.0, 0.0);
    }
}