rust_rocket/
track.rs

1//! [`Key`] and [`Track`] types.
2
3use crate::interpolation::*;
4
5/// The `Key` Type.
6#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
7#[cfg_attr(feature = "bincode", derive(bincode::Encode, bincode::Decode))]
8#[derive(Debug, Clone, Copy)]
9pub struct Key {
10    row: u32,
11    value: f32,
12    interpolation: Interpolation,
13}
14
15impl Key {
16    /// Construct a new `Key`.
17    pub fn new(row: u32, value: f32, interp: Interpolation) -> Key {
18        Key {
19            row,
20            value,
21            interpolation: interp,
22        }
23    }
24}
25
26/// The `Track` Type. This is a collection of `Key`s with a name.
27#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
28#[cfg_attr(feature = "bincode", derive(bincode::Encode, bincode::Decode))]
29#[derive(Debug, Clone)]
30pub struct Track {
31    name: String,
32    keys: Vec<Key>,
33}
34
35impl Track {
36    /// Construct a new Track with a name.
37    pub fn new<S: Into<String>>(name: S) -> Track {
38        Track {
39            name: name.into(),
40            keys: Vec::new(),
41        }
42    }
43
44    /// Get the name of the track.
45    pub fn get_name(&self) -> &str {
46        self.name.as_str()
47    }
48
49    fn get_exact_position(&self, row: u32) -> Option<usize> {
50        self.keys.iter().position(|k| k.row == row)
51    }
52
53    fn get_insert_position(&self, row: u32) -> Option<usize> {
54        self.keys.iter().position(|k| k.row >= row)
55    }
56
57    fn get_lower_bound_position(&self, row: u32) -> usize {
58        self.keys
59            .iter()
60            .position(|k| k.row > row)
61            .unwrap_or(self.keys.len())
62            - 1
63    }
64
65    /// Insert or update a key on a track.
66    pub fn set_key(&mut self, key: Key) {
67        if let Some(pos) = self.get_exact_position(key.row) {
68            self.keys[pos] = key;
69        } else if let Some(pos) = self.get_insert_position(key.row) {
70            self.keys.insert(pos, key);
71        } else {
72            self.keys.push(key);
73        }
74    }
75
76    /// Delete a key from a track.
77    ///
78    /// If a key does not exist this will do nothing.
79    pub fn delete_key(&mut self, row: u32) {
80        if let Some(pos) = self.get_exact_position(row) {
81            self.keys.remove(pos);
82        }
83    }
84
85    /// Get a value based on a row.
86    ///
87    /// The row can be between two integers.
88    /// This will perform the required interpolation.
89    pub fn get_value(&self, row: f32) -> f32 {
90        if self.keys.is_empty() {
91            return 0.0;
92        }
93
94        let lower_row = row.floor() as u32;
95
96        if lower_row <= self.keys[0].row {
97            return self.keys[0].value;
98        }
99
100        if lower_row >= self.keys[self.keys.len() - 1].row {
101            return self.keys[self.keys.len() - 1].value;
102        }
103
104        let pos = self.get_lower_bound_position(lower_row);
105
106        let lower = &self.keys[pos];
107        let higher = &self.keys[pos + 1];
108
109        let t = (row - (lower.row as f32)) / ((higher.row as f32) - (lower.row as f32));
110        let it = lower.interpolation.interpolate(t);
111
112        lower.value + (higher.value - lower.value) * it
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the reference track shared by every test in this module.
    fn reference_track() -> Track {
        let keys = [
            Key::new(0, 1.0, Interpolation::Step),
            Key::new(5, 0.0, Interpolation::Step),
            Key::new(10, 1.0, Interpolation::Linear),
            Key::new(20, 2.0, Interpolation::Linear),
        ];
        let mut track = Track::new("test");
        for key in keys {
            track.set_key(key);
        }
        track
    }

    /// Check the values `reference_track` is expected to produce.
    fn assert_reference_values(track: &Track) {
        // (row, expected value) pairs that must match exactly, in order.
        let exact: [(f32, f32); 8] = [
            (-1., 1.0),
            (0., 1.0),
            (1., 1.0),
            (4., 1.0),
            (5., 0.0),
            (6., 0.0),
            (9., 0.0),
            (10., 1.0),
        ];
        for &(row, expected) in exact.iter() {
            assert_eq!(track.get_value(row), expected);
        }

        assert!((track.get_value(15.) - 1.5).abs() <= f32::EPSILON);
        assert_eq!(track.get_value(21.), 2.0);
    }

    #[test]
    fn test_keys() {
        assert_reference_values(&reference_track());
    }

    #[test]
    #[cfg(feature = "bincode")]
    fn test_bincode_roundtrip() {
        let config = bincode::config::standard();
        let encoded = bincode::encode_to_vec(reference_track(), config).unwrap();
        let (decoded, _consumed): (Track, usize) =
            bincode::decode_from_slice(&encoded, config).unwrap();
        assert_reference_values(&decoded);
    }
}