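//! `thrust/intervals.rs`: intervals over an ordered type `T` and collections of them,
//! with union (`+`), difference (`-`) and intersection (`&`) operators.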

1use std::cmp::min;
2use std::fmt;
3use std::fmt::Display;
4use std::iter::Sum;
5use std::ops::{Add, BitAnd, Sub};
6
/// A pair of endpoints `start` and `stop` over a type `T` (integers, timestamps, ...).
///
/// The fields are public and `start <= stop` is not enforced by the type itself.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub struct Interval<T> {
    /// Beginning of the interval.
    pub start: T,
    /// End of the interval.
    pub stop: T,
}

/// Formats as `[start, stop]`, e.g. `[1, 2]`.
///
/// Implemented on `Interval<T>` itself so that intervals can be formatted by value;
/// references (`&Interval<T>`, `&&Interval<T>`, ...) still work through the standard
/// library's blanket `impl<T: Display + ?Sized> Display for &T`.
impl<T> Display for Interval<T>
where
    T: Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "[{}, {}]", self.start, self.stop)
    }
}

/// A list of intervals; the result type of the set-like operators (`+`, `-`, `&`)
/// defined in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IntervalCollection<T> {
    /// The intervals, in the order they are stored.
    pub elts: Vec<Interval<T>>,
}

/// Formats as a bracketed, comma-separated list of intervals, e.g. `[[0, 1], [2, 3]]`;
/// an empty collection prints as `[]`.
///
/// Like [`Interval`], implemented on the owned type; `&IntervalCollection<T>` is covered
/// by the standard library's blanket impl for references.
impl<T> Display for IntervalCollection<T>
where
    T: Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "[")?;
        for (i, elt) in self.elts.iter().enumerate() {
            if i > 0 {
                write!(f, ", ")?;
            }
            write!(f, "{elt}")?;
        }
        write!(f, "]")
    }
}
42
43impl<T> Add for &Interval<T>
44where
45    T: Ord + Copy,
46{
47    type Output = IntervalCollection<T>;
48    fn add(self, other: &Interval<T>) -> IntervalCollection<T> {
49        let left = IntervalCollection { elts: vec![*self] };
50        let right = IntervalCollection { elts: vec![*other] };
51        &left + &right
52    }
53}
54
55impl<T> Add for Interval<T>
56where
57    T: Ord + Copy,
58{
59    type Output = IntervalCollection<T>;
60    fn add(self, other: Interval<T>) -> IntervalCollection<T> {
61        let left = IntervalCollection { elts: vec![self] };
62        let right = IntervalCollection { elts: vec![other] };
63        &left + &right
64    }
65}
66
67impl<T> Add<IntervalCollection<T>> for &Interval<T>
68where
69    T: Ord + Copy,
70{
71    type Output = IntervalCollection<T>;
72    fn add(self, other: IntervalCollection<T>) -> IntervalCollection<T> {
73        let left = IntervalCollection { elts: vec![*self] };
74        &left + &other
75    }
76}
77
78impl<T> Add<&IntervalCollection<T>> for &Interval<T>
79where
80    T: Ord + Copy,
81{
82    type Output = IntervalCollection<T>;
83    fn add(self, other: &IntervalCollection<T>) -> IntervalCollection<T> {
84        let left = IntervalCollection { elts: vec![*self] };
85        &left + other
86    }
87}
88
89impl<T> Add<&Interval<T>> for &IntervalCollection<T>
90where
91    T: Ord + Copy,
92{
93    type Output = IntervalCollection<T>;
94    fn add(self, other: &Interval<T>) -> IntervalCollection<T> {
95        let right = IntervalCollection { elts: vec![*other] };
96        self + &right
97    }
98}
99
100impl<T> Add<Interval<T>> for IntervalCollection<T>
101where
102    T: Ord + Copy,
103{
104    type Output = IntervalCollection<T>;
105    fn add(self, other: Interval<T>) -> IntervalCollection<T> {
106        let right = IntervalCollection { elts: vec![other] };
107        self + right
108    }
109}
110
111impl<T> Add<&Interval<T>> for IntervalCollection<T>
112where
113    T: Ord + Copy,
114{
115    type Output = IntervalCollection<T>;
116    fn add(self, other: &Interval<T>) -> IntervalCollection<T> {
117        let right = IntervalCollection { elts: vec![*other] };
118        self + right
119    }
120}
121
122impl<T> Add for &IntervalCollection<T>
123where
124    T: Ord + Copy,
125{
126    type Output = IntervalCollection<T>;
127    fn add(self, other: &IntervalCollection<T>) -> IntervalCollection<T> {
128        let mut elts = Vec::new();
129        let mut start = min(&self.elts[0], &other.elts[0]);
130
131        loop {
132            let swiping_line = start.start;
133            let mut horizon = start.stop;
134
135            horizon = self
136                .elts
137                .iter()
138                .chain(other.elts.iter())
139                .filter(|elt| swiping_line <= elt.start && elt.start <= horizon)
140                .map(|elt| elt.stop)
141                .max()
142                .expect("Unexpected error");
143
144            loop {
145                match self
146                    .elts
147                    .iter()
148                    .chain(other.elts.iter())
149                    .filter(|elt| elt.start <= horizon && horizon < elt.stop)
150                    .map(|elt| elt.stop)
151                    .max()
152                {
153                    None => break,
154                    Some(x) => horizon = x,
155                }
156            }
157            elts.push(Interval {
158                start: swiping_line,
159                stop: horizon,
160            });
161            match self
162                .elts
163                .iter()
164                .chain(other.elts.iter())
165                .filter(|elt| elt.start > horizon)
166                .min()
167            {
168                None => break,
169                Some(x) => start = x,
170            }
171        }
172
173        IntervalCollection { elts }
174    }
175}
176
177impl<T> Add for IntervalCollection<T>
178where
179    T: Ord + Copy,
180{
181    type Output = IntervalCollection<T>;
182    fn add(self, other: IntervalCollection<T>) -> IntervalCollection<T> {
183        &self + &other
184    }
185}
186impl<T, Delta> Sub for Interval<T>
187where
188    T: Sub<T, Output = Delta> + Add<Delta, Output = T> + Copy + PartialOrd,
189    Delta: Copy,
190{
191    type Output = IntervalCollection<T>;
192    fn sub(self, other: Interval<T>) -> IntervalCollection<T> {
193        let mut elts = Vec::new();
194        if self.overlap(&other) {
195            if other.start > self.start {
196                elts.push(Interval {
197                    start: self.start,
198                    stop: other.start,
199                })
200            }
201            if other.stop < self.stop {
202                elts.push(Interval {
203                    start: other.stop,
204                    stop: self.stop,
205                })
206            }
207        } else {
208            elts.push(self)
209        }
210        IntervalCollection { elts }
211    }
212}
213
214impl<T, Delta> Sub<Interval<T>> for IntervalCollection<T>
215where
216    T: Sub<T, Output = Delta> + Add<Delta, Output = T> + Copy + PartialOrd,
217    Delta: Copy,
218{
219    type Output = IntervalCollection<T>;
220    fn sub(self, other: Interval<T>) -> IntervalCollection<T> {
221        let mut elts = Vec::new();
222        for elt in self.elts {
223            if elt.overlap(&other) {
224                if other.start > elt.start {
225                    elts.push(Interval {
226                        start: elt.start,
227                        stop: other.start,
228                    })
229                }
230                if other.stop < elt.stop {
231                    elts.push(Interval {
232                        start: other.stop,
233                        stop: elt.stop,
234                    })
235                }
236            } else {
237                elts.push(elt)
238            }
239        }
240        IntervalCollection { elts }
241    }
242}
243
244impl<T, Delta> Sub for IntervalCollection<T>
245where
246    T: Sub<T, Output = Delta> + Add<Delta, Output = T> + Copy + PartialOrd,
247    Delta: Copy,
248{
249    type Output = IntervalCollection<T>;
250    fn sub(self, other: IntervalCollection<T>) -> IntervalCollection<T> {
251        let mut res = self;
252        for elt in other.elts {
253            res = res - elt;
254        }
255        res
256    }
257}
258
259/* Implement intersection between two Intervals */
260impl<T> BitAnd for &Interval<T>
261where
262    T: Copy + Clone + PartialEq + PartialOrd,
263{
264    type Output = Option<Interval<T>>;
265    fn bitand(self, other: &Interval<T>) -> Option<Interval<T>> {
266        match self.overlap(other) {
267            true => Some(Interval {
268                start: match self.start > other.start {
269                    true => self.start,
270                    false => other.start,
271                },
272                stop: match self.stop < other.stop {
273                    true => self.stop,
274                    false => other.stop,
275                },
276            }),
277            false => None,
278        }
279    }
280}
281impl<T> BitAnd<&IntervalCollection<T>> for &Interval<T>
282where
283    T: Copy + Clone + PartialEq + PartialOrd,
284{
285    type Output = IntervalCollection<T>;
286    fn bitand(self, other: &IntervalCollection<T>) -> IntervalCollection<T> {
287        let mut elts = Vec::<Interval<T>>::with_capacity(other.elts.len());
288        for interval in &other.elts {
289            match self & interval {
290                None => (),
291                Some(i) => elts.push(i),
292            }
293        }
294        IntervalCollection { elts }
295    }
296}
297
298impl<T> BitAnd<&Interval<T>> for &IntervalCollection<T>
299where
300    T: Copy + Clone + PartialEq + PartialOrd,
301{
302    type Output = IntervalCollection<T>;
303    fn bitand(self, other: &Interval<T>) -> IntervalCollection<T> {
304        other & self
305    }
306}
307
308impl<T> BitAnd for &IntervalCollection<T>
309where
310    T: Copy + Clone + PartialEq + PartialOrd,
311{
312    type Output = IntervalCollection<T>;
313    fn bitand(self, other: &IntervalCollection<T>) -> IntervalCollection<T> {
314        let mut elts = Vec::<Interval<T>>::with_capacity(self.elts.len());
315        for interval in &other.elts {
316            let r = self & interval;
317            elts.extend(r.elts)
318        }
319        IntervalCollection { elts }
320    }
321}
322
323impl<T, Delta> Interval<T>
324where
325    T: Sub<T, Output = Delta> + Add<Delta, Output = T> + Copy,
326    Delta: Copy,
327{
328    pub fn duration(self) -> Delta {
329        self.stop - self.start
330    }
331    pub fn shift(&self, delta: Delta) -> Interval<T> {
332        Interval {
333            start: self.start + delta,
334            stop: self.stop + delta,
335        }
336    }
337}
338
339impl<T> Interval<T>
340where
341    T: PartialOrd,
342{
343    pub fn overlap(&self, other: &Interval<T>) -> bool {
344        self.start < other.stop && self.stop > other.start
345    }
346}
347
348impl<T, Delta> IntervalCollection<T>
349where
350    T: Sub<T, Output = Delta> + Add<Delta, Output = T> + Copy + PartialOrd,
351    Delta: Copy + Sum,
352{
353    pub fn total_duration(&self) -> Delta {
354        self.elts.iter().map(|elt| elt.duration()).sum()
355    }
356}
357
#[cfg(test)]
mod tests {

    use super::Interval;
    use jiff::{Timestamp, ToSpan};

    static I1: Interval<i32> = Interval { start: 0, stop: 1 };
    static I2: Interval<i32> = Interval { start: 1, stop: 2 };
    static I3: Interval<i32> = Interval { start: 2, stop: 3 };
    static I4: Interval<i32> = Interval { start: 3, stop: 4 };
    static I5: Interval<i32> = Interval { start: 4, stop: 5 };

    #[test]
    fn interval_i32() {
        assert_eq!(I1.duration(), 1);
        let shifted = I1.shift(1);
        assert_eq!(shifted.duration(), 1);
        assert_ne!(shifted, I1);
        assert_eq!(shifted, I2);
        assert_eq!(format!("{:?}", &shifted), "Interval { start: 1, stop: 2 }");
        assert_eq!(format!("{:}", &shifted), "[1, 2]");
    }

    #[test]
    fn interval_dt() {
        let i_dt: Interval<Timestamp> = Interval {
            start: "2024-01-20T12:00:00Z".parse().expect("error date"),
            stop: "2024-01-20T13:00:00Z".parse().expect("error date"),
        };
        assert_eq!(i_dt.duration().compare(1.hour()).unwrap(), std::cmp::Ordering::Equal);
        assert_eq!(
            i_dt.shift(5.hour()).duration().compare(1.hour()).unwrap(),
            std::cmp::Ordering::Equal
        );
    }

    #[test]
    fn intervals_consistent() {
        assert_eq!(
            format!("{:?}", I1 + I2),
            "IntervalCollection { elts: [Interval { start: 0, stop: 2 }] }"
        );
        assert_eq!(format!("{:}", &(I1 + I2)), "[[0, 2]]");
        assert_eq!(format!("{:}", &(I1 + I3)), "[[0, 1], [2, 3]]");
        assert_eq!(format!("{:}", &(I2 + I4)), "[[1, 2], [3, 4]]");
        let s1 = (I1 + I3) + (I2 + I4);
        assert_eq!(format!("{:}", &s1), "[[0, 4]]");
        let s2 = (I1 + I3) + (I4 + I5);
        assert_eq!(format!("{:}", &s2), "[[0, 1], [2, 5]]");
        let s3 = I1 + I3 + I4 + I5;
        assert_eq!(format!("{:}", &s3), "[[0, 1], [2, 5]]");

        let i1: Interval<Timestamp> = Interval {
            start: "2024-01-20T12:00:00Z".parse().expect("error date"),
            stop: "2024-01-20T13:00:00Z".parse().expect("error date"),
        };
        let i2 = Interval {
            start: "2024-01-20T13:00:00Z".parse().expect("error date"),
            stop: "2024-01-20T14:00:00Z".parse().expect("error date"),
        };
        assert_eq!(
            format!("{:}", &(i1 + i2)),
            "[[2024-01-20T12:00:00Z, 2024-01-20T14:00:00Z]]"
        );
    }

    #[test]
    fn intervals_sub() {
        assert_eq!(format!("{:}", &(I1 - I2)), "[[0, 1]]");
        assert_eq!(format!("{:}", &(Interval { start: 0, stop: 2 } - I2)), "[[0, 1]]");
        assert_eq!(format!("{:}", &((I1 + I2 + I3) - I2)), "[[0, 1], [2, 3]]");
        assert_eq!(format!("{:}", &((I1 + I2) - (I3 + I2))), "[[0, 1]]");
        assert_eq!(
            format!("{:}", &(((I1 + I2) + (I2 + I3) + I5) - (I2 + I3))),
            "[[0, 1], [4, 5]]"
        );
    }

    #[test]
    fn intervals_sub_edge_cases() {
        // Removing an interval from itself leaves nothing.
        assert_eq!(format!("{:}", &(I1 - I1)), "[]");
        // Removing an inner interval splits the outer one in two.
        assert_eq!(
            format!("{:}", &(Interval { start: 0, stop: 5 } - I2)),
            "[[0, 1], [2, 5]]"
        );
    }

    #[test]
    fn intervals_intersection() {
        // Intervals that only touch do not intersect.
        assert_eq!(&I1 & &I2, None);
        assert_eq!(&Interval { start: 0, stop: 3 } & &I2, Some(I2));

        let left = Interval { start: 0, stop: 2 } + Interval { start: 3, stop: 5 };
        let right = I2 + I3 + I4;
        assert_eq!(format!("{:}", &right), "[[1, 4]]");
        assert_eq!(format!("{:}", &(&left & &right)), "[[1, 2], [3, 4]]");
        assert_eq!(format!("{:}", &(&I1 & &right)), "[]");
        assert_eq!(format!("{:}", &(&(I1 + I3) & &(I2 + I4))), "[]");
    }

    #[test]
    fn collection_total_duration() {
        assert_eq!((I1 + I3 + I4 + I5).total_duration(), 4);
        assert_eq!((I1 - I1).total_duration(), 0);
    }
}
430            format!("{:}", &(((I1 + I2) + (I2 + I3) + I5) - (I2 + I3))),
431            "[[0, 1], [4, 5]]"
432        );
433    }
434}