1use super::{AnySocket, Error, Instant, Result, Socket, SocketRef, SocketType};
2use atat::atat_derive::AtatLen;
3use heapless::Vec;
4use serde::{Deserialize, Serialize};
5
6#[derive(
8 Debug,
9 Clone,
10 Copy,
11 PartialEq,
12 Eq,
13 PartialOrd,
14 AtatLen,
15 Ord,
16 hash32_derive::Hash32,
17 Default,
18 Serialize,
19 Deserialize,
20)]
21#[cfg_attr(feature = "defmt", derive(defmt::Format))]
22pub struct Handle(pub u8);
23
24#[derive(Default, Debug)]
26pub struct Set<const TIMER_HZ: u32, const N: usize, const L: usize> {
27 pub sockets: Vec<Option<Socket<TIMER_HZ, L>>, N>,
28}
29
30impl<const TIMER_HZ: u32, const N: usize, const L: usize> Set<TIMER_HZ, N, L> {
31 pub fn new() -> Set<TIMER_HZ, N, L> {
33 let mut sockets = Vec::new();
34 while sockets.len() < N {
35 sockets.push(None).ok();
36 }
37 Set { sockets }
38 }
39
40 pub fn capacity(&self) -> usize {
42 N
43 }
44
45 pub fn len(&self) -> usize {
47 self.sockets.iter().filter(|a| a.is_some()).count()
48 }
49
50 pub fn is_empty(&self) -> bool {
52 self.len() == 0
53 }
54
55 pub fn socket_type(&self, handle: Handle) -> Option<SocketType> {
59 if let Ok(index) = self.index_of(handle) {
60 if let Some(socket) = self.sockets.get(index) {
61 return socket.as_ref().map(|s| s.get_type());
62 }
63 }
64 None
65 }
66
67 pub fn add<T>(&mut self, socket: T) -> Result<Handle>
69 where
70 T: Into<Socket<TIMER_HZ, L>>,
71 {
72 let socket = socket.into();
73 let handle = socket.handle();
74
75 debug!(
76 "[Socket Set] Adding: {} {:?} to: {:?}",
77 handle.0,
78 socket.get_type(),
79 self
80 );
81
82 if self.index_of(handle).is_ok() {
83 return Err(Error::DuplicateSocket);
84 }
85
86 self.sockets
87 .iter_mut()
88 .find(|s| s.is_none())
89 .ok_or(Error::SocketSetFull)?
90 .replace(socket);
91
92 Ok(handle)
93 }
94
95 pub fn get<T: AnySocket<TIMER_HZ, L>>(&mut self, handle: Handle) -> Result<SocketRef<T>> {
97 let index = self.index_of(handle)?;
98
99 match self.sockets.get_mut(index).ok_or(Error::InvalidSocket)? {
100 Some(socket) => Ok(T::downcast(SocketRef::new(socket))?),
101 None => Err(Error::InvalidSocket),
102 }
103 }
104
105 fn index_of(&self, handle: Handle) -> Result<usize> {
107 self.sockets
108 .iter()
109 .position(|i| {
110 i.as_ref()
111 .map(|s| s.handle().0 == handle.0)
112 .unwrap_or(false)
113 })
114 .ok_or(Error::InvalidSocket)
115 }
116
117 pub fn remove(&mut self, handle: Handle) -> Result<()> {
119 let index = self.index_of(handle)?;
120 let item: &mut Option<Socket<TIMER_HZ, L>> =
121 self.sockets.get_mut(index).ok_or(Error::InvalidSocket)?;
122
123 debug!(
124 "[Socket Set] Removing socket! {} {:?}",
125 handle.0,
126 item.as_ref().map(|i| i.get_type())
127 );
128
129 item.take().ok_or(Error::InvalidSocket)?;
130 Ok(())
131 }
132
133 pub fn prune(&mut self) {
137 debug!("[Socket Set] Pruning: {:?}", self);
138 self.sockets.iter_mut().enumerate().for_each(|(_, slot)| {
139 slot.take();
140 })
141 }
142
143 pub fn recycle(&mut self, ts: Instant<TIMER_HZ>) -> bool {
144 let h = self.iter().find(|(_, s)| s.recycle(ts)).map(|(h, _)| h);
145 if h.is_none() {
146 return false;
147 }
148 self.remove(h.unwrap()).is_ok()
149 }
150
151 pub fn iter(&self) -> impl Iterator<Item = (Handle, &Socket<TIMER_HZ, L>)> {
153 self.sockets.iter().filter_map(|slot| {
154 if let Some(socket) = slot {
155 Some((Handle(socket.handle().0), socket))
156 } else {
157 None
158 }
159 })
160 }
161
162 pub fn iter_mut(&mut self) -> impl Iterator<Item = (Handle, SocketRef<Socket<TIMER_HZ, L>>)> {
164 self.sockets.iter_mut().filter_map(|slot| {
165 if let Some(socket) = slot {
166 Some((Handle(socket.handle().0), SocketRef::new(socket)))
167 } else {
168 None
169 }
170 })
171 }
172}
173
#[cfg(feature = "defmt")]
impl<const TIMER_HZ: u32, const N: usize, const L: usize> defmt::Format for Set<TIMER_HZ, N, L> {
    /// Writes each occupied slot as `[handle, KIND(state)],`, and wraps the
    /// whole list in one outer pair of brackets.
    fn format(&self, fmt: defmt::Formatter) {
        defmt::write!(fmt, "[");
        for (handle, entry) in self.iter() {
            match entry {
                Socket::Tcp(tcp) => {
                    defmt::write!(fmt, "[{:?}, TCP({:?})],", handle, tcp.state())
                }
                Socket::Udp(udp) => {
                    defmt::write!(fmt, "[{:?}, UDP({:?})],", handle, udp.state())
                }
            }
        }
        defmt::write!(fmt, "]");
    }
}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{TcpSocket, UdpSocket};

    use fugit::{ExtU32, MillisDurationU32};
    use fugit_timer::Timer;
    use std::convert::Infallible;

    const TIMER_HZ: u32 = 1000;

    /// `fugit_timer::Timer` implementation backed by `std::time::Instant`,
    /// with millisecond ticks (`TIMER_HZ = 1000`).
    pub struct MockTimer {
        monotonic: std::time::Instant,
        start: Option<std::time::Instant>,
        duration: MillisDurationU32,
    }

    impl MockTimer {
        pub fn new() -> MockTimer {
            MockTimer {
                monotonic: std::time::Instant::now(),
                start: None,
                duration: MillisDurationU32::millis(0),
            }
        }
    }

    impl Timer<TIMER_HZ> for MockTimer {
        type Error = Infallible;

        fn now(&mut self) -> fugit::TimerInstantU32<TIMER_HZ> {
            let millis = self.monotonic.elapsed().as_millis();
            fugit::TimerInstantU32::from_ticks(millis as u32)
        }

        fn start(
            &mut self,
            duration: fugit::TimerDurationU32<TIMER_HZ>,
        ) -> std::result::Result<(), Self::Error> {
            self.start = Some(std::time::Instant::now());
            self.duration = duration.convert();
            Ok(())
        }

        fn cancel(&mut self) -> std::result::Result<(), Self::Error> {
            // Clearing an already-cleared timer is a no-op, so no guard is needed.
            self.start = None;
            Ok(())
        }

        fn wait(&mut self) -> nb::Result<(), Self::Error> {
            if let Some(start) = self.start {
                let now = std::time::Instant::now();
                if now - start > std::time::Duration::from_millis(self.duration.ticks() as u64) {
                    Ok(())
                } else {
                    std::thread::sleep(std::time::Duration::from_millis(1));
                    Err(nb::Error::WouldBlock)
                }
            } else {
                Ok(())
            }
        }
    }

    /// Builds a two-slot set holding a TCP socket (handle 0) and a UDP socket
    /// (handle 1), checking each insertion along the way.
    fn filled_set() -> Set<TIMER_HZ, 2, 64> {
        let mut set = Set::<TIMER_HZ, 2, 64>::new();

        assert_eq!(set.add(TcpSocket::new(0)), Ok(Handle(0)));
        assert_eq!(set.len(), 1);
        assert_eq!(set.add(UdpSocket::new(1)), Ok(Handle(1)));
        assert_eq!(set.len(), 2);

        set
    }

    #[test]
    fn mock_timer_works() {
        // A short duration keeps the suite fast while still exercising the
        // WouldBlock -> Ok transition of `wait`.
        let now = std::time::Instant::now();

        let mut timer = MockTimer::new();
        timer.start(50.millis()).unwrap();
        nb::block!(timer.wait()).unwrap();

        assert!(now.elapsed().as_millis() >= 50);
    }

    #[test]
    fn add_socket() {
        let mut set = Set::<TIMER_HZ, 2, 64>::new();

        assert_eq!(set.add(TcpSocket::new(0)), Ok(Handle(0)));
        assert_eq!(set.len(), 1);
        assert_eq!(set.add(UdpSocket::new(1)), Ok(Handle(1)));
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn remove_socket() {
        let mut set = filled_set();

        assert!(set.remove(Handle(0)).is_ok());
        assert_eq!(set.len(), 1);

        assert!(set.get::<TcpSocket<TIMER_HZ, 64>>(Handle(0)).is_err());

        set.get::<UdpSocket<TIMER_HZ, 64>>(Handle(1))
            .expect("failed to get udp socket");
    }

    #[test]
    fn remove_missing_socket() {
        let mut set = filled_set();

        assert_eq!(set.remove(Handle(7)), Err(Error::InvalidSocket));
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn add_duplicate_socket() {
        let mut set = Set::<TIMER_HZ, 2, 64>::new();

        assert_eq!(set.add(TcpSocket::new(0)), Ok(Handle(0)));
        assert_eq!(set.len(), 1);
        assert_eq!(set.add(UdpSocket::new(0)), Err(Error::DuplicateSocket));
    }

    #[test]
    fn add_socket_to_full_set() {
        let mut set = filled_set();

        assert_eq!(set.add(UdpSocket::new(2)), Err(Error::SocketSetFull));
    }

    #[test]
    fn get_socket() {
        let mut set = filled_set();

        set.get::<TcpSocket<TIMER_HZ, 64>>(Handle(0))
            .expect("failed to get tcp socket");

        set.get::<UdpSocket<TIMER_HZ, 64>>(Handle(1))
            .expect("failed to get udp socket");
    }

    #[test]
    fn get_socket_wrong_type() {
        let mut set = filled_set();

        assert!(set.get::<TcpSocket<TIMER_HZ, 64>>(Handle(1)).is_err());

        set.get::<UdpSocket<TIMER_HZ, 64>>(Handle(1))
            .expect("failed to get udp socket");
    }

    #[test]
    fn get_socket_type() {
        let set = filled_set();

        assert_eq!(set.socket_type(Handle(0)), Some(SocketType::Tcp));
        assert_eq!(set.socket_type(Handle(1)), Some(SocketType::Udp));
        assert_eq!(set.socket_type(Handle(7)), None);
    }

    #[test]
    fn replace_socket() {
        let mut set = filled_set();

        assert!(set.remove(Handle(0)).is_ok());
        assert_eq!(set.len(), 1);

        assert!(set.get::<TcpSocket<TIMER_HZ, 64>>(Handle(0)).is_err());

        set.get::<UdpSocket<TIMER_HZ, 64>>(Handle(1))
            .expect("failed to get udp socket");

        assert_eq!(set.add(TcpSocket::new(0)), Ok(Handle(0)));
        assert_eq!(set.len(), 2);

        set.get::<TcpSocket<TIMER_HZ, 64>>(Handle(0))
            .expect("failed to get tcp socket");
    }

    #[test]
    fn prune_socket_set() {
        let mut set = filled_set();

        set.get::<TcpSocket<TIMER_HZ, 64>>(Handle(0))
            .expect("failed to get tcp socket");

        set.prune();
        assert_eq!(set.len(), 0);
        assert!(set.is_empty());

        // Pruning frees the slots without shrinking the set, so it can be refilled.
        assert_eq!(set.add(TcpSocket::new(0)), Ok(Handle(0)));
        assert_eq!(set.len(), 1);
    }
}
396}