ublox_sockets/set.rs

1use super::{AnySocket, Error, Instant, Result, Socket, SocketRef, SocketType};
2use atat::atat_derive::AtatLen;
3use heapless::Vec;
4use serde::{Deserialize, Serialize};
5
6/// A handle, identifying a socket in a set.
7#[derive(
8    Debug,
9    Clone,
10    Copy,
11    PartialEq,
12    Eq,
13    PartialOrd,
14    AtatLen,
15    Ord,
16    hash32_derive::Hash32,
17    Default,
18    Serialize,
19    Deserialize,
20)]
21#[cfg_attr(feature = "defmt", derive(defmt::Format))]
22pub struct Handle(pub u8);
23
24/// An extensible set of sockets.
25#[derive(Default, Debug)]
26pub struct Set<const TIMER_HZ: u32, const N: usize, const L: usize> {
27    pub sockets: Vec<Option<Socket<TIMER_HZ, L>>, N>,
28}
29
30impl<const TIMER_HZ: u32, const N: usize, const L: usize> Set<TIMER_HZ, N, L> {
31    /// Create a socket set using the provided storage.
32    pub fn new() -> Set<TIMER_HZ, N, L> {
33        let mut sockets = Vec::new();
34        while sockets.len() < N {
35            sockets.push(None).ok();
36        }
37        Set { sockets }
38    }
39
40    /// Get the maximum number of sockets the set can hold
41    pub fn capacity(&self) -> usize {
42        N
43    }
44
45    /// Get the current number of initialized sockets, the set is holding
46    pub fn len(&self) -> usize {
47        self.sockets.iter().filter(|a| a.is_some()).count()
48    }
49
50    /// Check if the set is currently holding no active sockets
51    pub fn is_empty(&self) -> bool {
52        self.len() == 0
53    }
54
55    /// Get the type of a specific socket in the set.
56    ///
57    /// Returned as a [`SocketType`]
58    pub fn socket_type(&self, handle: Handle) -> Option<SocketType> {
59        if let Ok(index) = self.index_of(handle) {
60            if let Some(socket) = self.sockets.get(index) {
61                return socket.as_ref().map(|s| s.get_type());
62            }
63        }
64        None
65    }
66
67    /// Add a socket to the set with the reference count 1, and return its handle.
68    pub fn add<T>(&mut self, socket: T) -> Result<Handle>
69    where
70        T: Into<Socket<TIMER_HZ, L>>,
71    {
72        let socket = socket.into();
73        let handle = socket.handle();
74
75        debug!(
76            "[Socket Set] Adding: {} {:?} to: {:?}",
77            handle.0,
78            socket.get_type(),
79            self
80        );
81
82        if self.index_of(handle).is_ok() {
83            return Err(Error::DuplicateSocket);
84        }
85
86        self.sockets
87            .iter_mut()
88            .find(|s| s.is_none())
89            .ok_or(Error::SocketSetFull)?
90            .replace(socket);
91
92        Ok(handle)
93    }
94
95    /// Get a socket from the set by its handle, as mutable.
96    pub fn get<T: AnySocket<TIMER_HZ, L>>(&mut self, handle: Handle) -> Result<SocketRef<T>> {
97        let index = self.index_of(handle)?;
98
99        match self.sockets.get_mut(index).ok_or(Error::InvalidSocket)? {
100            Some(socket) => Ok(T::downcast(SocketRef::new(socket))?),
101            None => Err(Error::InvalidSocket),
102        }
103    }
104
105    /// Get the index of a given socket in the set.
106    fn index_of(&self, handle: Handle) -> Result<usize> {
107        self.sockets
108            .iter()
109            .position(|i| {
110                i.as_ref()
111                    .map(|s| s.handle().0 == handle.0)
112                    .unwrap_or(false)
113            })
114            .ok_or(Error::InvalidSocket)
115    }
116
117    /// Remove a socket from the set
118    pub fn remove(&mut self, handle: Handle) -> Result<()> {
119        let index = self.index_of(handle)?;
120        let item: &mut Option<Socket<TIMER_HZ, L>> =
121            self.sockets.get_mut(index).ok_or(Error::InvalidSocket)?;
122
123        debug!(
124            "[Socket Set] Removing socket! {} {:?}",
125            handle.0,
126            item.as_ref().map(|i| i.get_type())
127        );
128
129        item.take().ok_or(Error::InvalidSocket)?;
130        Ok(())
131    }
132
133    /// Prune the sockets in this set.
134    ///
135    /// All sockets are removed and dropped.
136    pub fn prune(&mut self) {
137        debug!("[Socket Set] Pruning: {:?}", self);
138        self.sockets.iter_mut().enumerate().for_each(|(_, slot)| {
139            slot.take();
140        })
141    }
142
143    pub fn recycle(&mut self, ts: Instant<TIMER_HZ>) -> bool {
144        let h = self.iter().find(|(_, s)| s.recycle(ts)).map(|(h, _)| h);
145        if h.is_none() {
146            return false;
147        }
148        self.remove(h.unwrap()).is_ok()
149    }
150
151    /// Iterate every socket in this set.
152    pub fn iter(&self) -> impl Iterator<Item = (Handle, &Socket<TIMER_HZ, L>)> {
153        self.sockets.iter().filter_map(|slot| {
154            if let Some(socket) = slot {
155                Some((Handle(socket.handle().0), socket))
156            } else {
157                None
158            }
159        })
160    }
161
162    /// Iterate every socket in this set, as SocketRef.
163    pub fn iter_mut(&mut self) -> impl Iterator<Item = (Handle, SocketRef<Socket<TIMER_HZ, L>>)> {
164        self.sockets.iter_mut().filter_map(|slot| {
165            if let Some(socket) = slot {
166                Some((Handle(socket.handle().0), SocketRef::new(socket)))
167            } else {
168                None
169            }
170        })
171    }
172}
173
#[cfg(feature = "defmt")]
impl<const TIMER_HZ: u32, const N: usize, const L: usize> defmt::Format for Set<TIMER_HZ, N, L> {
    /// Writes the set as a bracketed list with one `[handle, KIND(state)],`
    /// entry per occupied slot, in slot order.
    fn format(&self, fmt: defmt::Formatter) {
        defmt::write!(fmt, "[");
        for (handle, entry) in self.iter() {
            match entry {
                Socket::Udp(udp) => defmt::write!(fmt, "[{:?}, UDP({:?})],", handle, udp.state()),
                Socket::Tcp(tcp) => defmt::write!(fmt, "[{:?}, TCP({:?})],", handle, tcp.state()),
            }
        }
        defmt::write!(fmt, "]");
    }
}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{TcpSocket, UdpSocket};

    use fugit::{ExtU32, MillisDurationU32};
    use fugit_timer::Timer;
    use std::convert::Infallible;

    const TIMER_HZ: u32 = 1000;

    /// Host-side `fugit_timer::Timer` backed by `std::time::Instant`, ticking
    /// in milliseconds (`TIMER_HZ = 1000`).
    pub struct MockTimer {
        // Reference point for `now()`; captured when the timer is created.
        monotonic: std::time::Instant,
        // When the current countdown started, or `None` if none is running.
        start: Option<std::time::Instant>,
        // Length of the current countdown.
        duration: MillisDurationU32,
    }

    impl MockTimer {
        pub fn new() -> MockTimer {
            MockTimer {
                monotonic: std::time::Instant::now(),
                start: None,
                duration: MillisDurationU32::millis(0),
            }
        }
    }

    impl Timer<TIMER_HZ> for MockTimer {
        type Error = Infallible;

        fn now(&mut self) -> fugit::TimerInstantU32<TIMER_HZ> {
            // Milliseconds since creation; `as u32` truncates the u128, which
            // is irrelevant for test durations.
            let millis = self.monotonic.elapsed().as_millis();
            fugit::TimerInstantU32::from_ticks(millis as u32)
        }

        fn start(
            &mut self,
            duration: fugit::TimerDurationU32<TIMER_HZ>,
        ) -> std::result::Result<(), Self::Error> {
            self.start = Some(std::time::Instant::now());
            self.duration = duration.convert();
            Ok(())
        }

        fn cancel(&mut self) -> std::result::Result<(), Self::Error> {
            self.start = None;
            Ok(())
        }

        fn wait(&mut self) -> nb::Result<(), Self::Error> {
            if let Some(start) = self.start {
                let now = std::time::Instant::now();
                if now - start > std::time::Duration::from_millis(self.duration.ticks() as u64) {
                    Ok(())
                } else {
                    // Throttle callers that spin on `WouldBlock`.
                    std::thread::sleep(std::time::Duration::from_millis(1));
                    Err(nb::Error::WouldBlock)
                }
            } else {
                Ok(())
            }
        }
    }

    /// Builds a two-slot set holding a TCP socket (handle 0) and a UDP socket
    /// (handle 1), asserting that each step succeeds.
    fn tcp_udp_set() -> Set<TIMER_HZ, 2, 64> {
        let mut set = Set::<TIMER_HZ, 2, 64>::new();

        assert_eq!(set.add(TcpSocket::new(0)), Ok(Handle(0)));
        assert_eq!(set.len(), 1);
        assert_eq!(set.add(UdpSocket::new(1)), Ok(Handle(1)));
        assert_eq!(set.len(), 2);

        set
    }

    #[test]
    fn mock_timer_works() {
        let now = std::time::Instant::now();

        let mut timer = MockTimer::new();
        timer.start(1000.millis()).unwrap();
        nb::block!(timer.wait()).unwrap();
        assert!(now.elapsed().as_millis() >= 1_000);
    }

    #[test]
    fn add_socket() {
        let set = tcp_udp_set();
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn remove_socket() {
        let mut set = tcp_udp_set();

        assert!(set.remove(Handle(0)).is_ok());
        assert_eq!(set.len(), 1);

        assert!(set.get::<TcpSocket<TIMER_HZ, 64>>(Handle(0)).is_err());

        set.get::<UdpSocket<TIMER_HZ, 64>>(Handle(1))
            .expect("failed to get udp socket");
    }

    #[test]
    fn add_duplicate_socket() {
        let mut set = Set::<TIMER_HZ, 2, 64>::new();

        assert_eq!(set.add(TcpSocket::new(0)), Ok(Handle(0)));
        assert_eq!(set.len(), 1);
        assert_eq!(set.add(UdpSocket::new(0)), Err(Error::DuplicateSocket));
    }

    #[test]
    fn add_socket_to_full_set() {
        let mut set = tcp_udp_set();

        assert_eq!(set.add(UdpSocket::new(2)), Err(Error::SocketSetFull));
    }

    #[test]
    fn get_socket() {
        let mut set = tcp_udp_set();

        set.get::<TcpSocket<TIMER_HZ, 64>>(Handle(0))
            .expect("failed to get tcp socket");

        set.get::<UdpSocket<TIMER_HZ, 64>>(Handle(1))
            .expect("failed to get udp socket");
    }

    #[test]
    fn get_socket_wrong_type() {
        let mut set = tcp_udp_set();

        assert!(set.get::<TcpSocket<TIMER_HZ, 64>>(Handle(1)).is_err());

        set.get::<UdpSocket<TIMER_HZ, 64>>(Handle(1))
            .expect("failed to get udp socket");
    }

    #[test]
    fn get_socket_type() {
        let set = tcp_udp_set();

        assert_eq!(set.socket_type(Handle(0)), Some(SocketType::Tcp));
        assert_eq!(set.socket_type(Handle(1)), Some(SocketType::Udp));
    }

    #[test]
    fn replace_socket() {
        let mut set = tcp_udp_set();

        assert!(set.remove(Handle(0)).is_ok());
        assert_eq!(set.len(), 1);

        assert!(set.get::<TcpSocket<TIMER_HZ, 64>>(Handle(0)).is_err());

        set.get::<UdpSocket<TIMER_HZ, 64>>(Handle(1))
            .expect("failed to get udp socket");

        assert_eq!(set.add(TcpSocket::new(0)), Ok(Handle(0)));
        assert_eq!(set.len(), 2);

        set.get::<TcpSocket<TIMER_HZ, 64>>(Handle(0))
            .expect("failed to get tcp socket");
    }

    #[test]
    fn prune_socket_set() {
        let mut set = tcp_udp_set();

        set.get::<TcpSocket<TIMER_HZ, 64>>(Handle(0))
            .expect("failed to get tcp socket");

        set.prune();
        assert_eq!(set.len(), 0);
    }
}