wick_packet/
stream_map.rs

1use std::collections::HashMap;
2
3use tokio_stream::StreamExt;
4
5use crate::{Error, Packet, PacketSender, PacketStream};
6pub(crate) type Result<T> = std::result::Result<T, Error>;
7
8#[derive(Default)]
9#[must_use]
10/// A wrapper for a map of [String]s to [PacketStream]s.
11pub struct StreamMap {
12  inner: HashMap<String, PacketStream>,
13}
14
15impl std::fmt::Debug for StreamMap {
16  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17    f.debug_tuple("StreamMap").field(&self.inner.keys()).finish()
18  }
19}
20
21impl StreamMap {
22  /// Remove a stream from the map by key name.
23  pub fn take(&mut self, key: &str) -> Result<PacketStream> {
24    let v = self
25      .inner
26      .remove(key)
27      .ok_or_else(|| crate::Error::PortMissing(key.to_owned()))?;
28    Ok(v)
29  }
30
31  /// Get the keys in the map.
32  pub fn keys(&self) -> impl Iterator<Item = &String> {
33    self.inner.keys()
34  }
35
36  /// Take the next packet from the stream keyed by `key`.
37  pub async fn next_for(&mut self, key: &str) -> Option<Result<Packet>> {
38    let stream = self.inner.get_mut(key)?;
39    stream.next().await
40  }
41
42  /// Take one packet from each stream in the map. Returns an error if a complete set can't be made.
43  pub async fn next_set(&mut self) -> Result<Option<HashMap<String, Packet>>> {
44    let keys = self.inner.keys().cloned().collect::<Vec<_>>();
45    let mut raw = HashMap::new();
46    for key in keys {
47      let packet = self.next_for(&key).await;
48      raw.insert(key, packet);
49    }
50    if raw.values().all(|v| v.is_none()) {
51      Ok(None)
52    } else if let Some((name, _)) = raw.iter().find(|(_, p)| p.is_none()) {
53      Err(Error::StreamMapMissing(name.clone()))
54    } else {
55      let mut rv = HashMap::new();
56      for (key, packet) in raw {
57        let packet = packet.unwrap();
58        if let Err(e) = &packet {
59          return Err(Error::StreamMapError(key, e.to_string()));
60        }
61
62        rv.insert(key, packet.unwrap());
63      }
64      Ok(Some(rv))
65    }
66  }
67
68  #[cfg(feature = "rt-tokio")]
69  /// Turn a single [PacketStream] into a [StreamMap] keyed by the passed `ports`.
70  pub fn from_stream(mut stream: PacketStream, ports: impl IntoIterator<Item = String>) -> Self {
71    use tracing::warn;
72    use wasmrs_rx::Observer;
73
74    use crate::PacketExt;
75
76    #[must_use]
77    let mut streams = StreamMap::default();
78    let mut senders = HashMap::new();
79    for port in ports {
80      senders.insert(port.clone(), streams.init(&port));
81    }
82    tokio::spawn(async move {
83      while let Some(Ok(packet)) = stream.next().await {
84        if packet.is_fatal_error() {
85          for (name, sender) in &mut senders {
86            let _ = sender.send(packet.clone().to_port(name));
87          }
88        } else {
89          let Some(sender) = senders.get_mut(packet.port()) else {
90            if !packet.is_noop() {
91              warn!("received packet for unknown port: {}", packet.port());
92            }
93            continue;
94          };
95          let is_done = packet.is_done();
96          let _ = sender.send(packet);
97          if is_done {
98            sender.complete();
99          }
100        }
101      }
102    });
103    streams
104  }
105
106  pub fn init(&mut self, port: &str) -> PacketSender {
107    let flux = PacketSender::default();
108    self
109      .inner
110      .insert(port.to_owned(), PacketStream::new(Box::new(flux.take_rx().unwrap())));
111    flux
112  }
113}
114
115impl IntoIterator for StreamMap {
116  type Item = (String, PacketStream);
117
118  type IntoIter = std::collections::hash_map::IntoIter<String, PacketStream>;
119
120  fn into_iter(self) -> Self::IntoIter {
121    self.inner.into_iter()
122  }
123}