wasmflow_traits/
writable_port.rs

1use serde::Serialize;
2use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
3use tokio_stream::wrappers::UnboundedReceiverStream;
4use tokio_stream::{StreamExt, StreamMap};
5use wasmflow_packet::v1::Packet as V1;
6use wasmflow_packet::{Packet, PacketWrapper};
7use wasmflow_streams::PacketStream;
8
/// Boxed, thread-safe error type shared by every fallible operation in this module.
type Error = Box<dyn std::error::Error + Sync + Send + 'static>;

/// Unit result whose failure case is the boxed [Error] above.
type Result = std::result::Result<(), Error>;
12
13fn send_message(port: &PortChannel, name: impl AsRef<str>, packet: Packet) -> Result {
14  port.send(PacketWrapper {
15    payload: packet,
16    port: name.as_ref().to_owned(),
17  })
18}
19
20/// The native PortSender trait. This trait encapsulates sending messages out of native ports.
21pub trait Writable {
22  /// The type of data that the port outputs.
23  type PayloadType: Serialize;
24
25  /// Get the port buffer that the sender can push to.
26  fn get_port(&self) -> std::result::Result<&PortChannel, Error>;
27
28  /// Get the port's name.
29  fn get_port_name(&self) -> &str;
30
31  /// Return the ID of the transaction.
32  fn get_id(&self) -> u32;
33
34  /// Send a message then close the port.
35  fn done(&self, data: Self::PayloadType) -> Result {
36    let port = self.get_port()?;
37    let name = self.get_port_name();
38    send_message(port, name, Packet::V1(V1::success(&data)))?;
39    send_message(port, name, Packet::V1(V1::done()))
40  }
41
42  /// Send a payload then close the port.
43  fn done_message(&self, packet: Packet) -> Result {
44    let port = self.get_port()?;
45    let name = self.get_port_name();
46    send_message(port, name, packet)?;
47    send_message(port, name, Packet::V1(V1::done()))
48  }
49
50  /// Send an exception then close the port.
51  fn done_exception(&self, payload: String) -> Result {
52    let port = self.get_port()?;
53    let name = self.get_port_name();
54    send_message(port, name, V1::exception(payload).into())?;
55    send_message(port, name, Packet::V1(V1::done()))
56  }
57}
58
59/// A [PortChannel] wraps an unbounded channel with a port name.
60#[must_use]
61#[derive(Debug, Clone)]
62pub struct PortChannel {
63  /// Port name.
64  pub name: String,
65  incoming: Option<UnboundedSender<PacketWrapper>>,
66}
67
68impl PortChannel {
69  /// Constructor for a [PortChannel].
70  pub fn new<T: AsRef<str>>(name: T) -> Self {
71    Self {
72      name: name.as_ref().to_owned(),
73      incoming: None,
74    }
75  }
76
77  /// Initialize the [PortChannel] and return a receiver.
78  pub fn open(&mut self) -> UnboundedReceiverStream<PacketWrapper> {
79    let (tx, rx) = unbounded_channel();
80    self.incoming = Some(tx);
81    UnboundedReceiverStream::new(rx)
82  }
83
84  /// Drop the incoming channel, closing the upstream.
85  pub fn close(&mut self) {
86    self.incoming.take();
87  }
88
89  /// Returns true if the port still has an active upstream.
90  #[must_use]
91  pub fn is_closed(&self) -> bool {
92    self.incoming.is_none()
93  }
94
95  /// Send a messages to the channel.
96  pub fn send(&self, msg: PacketWrapper) -> Result {
97    let incoming = self
98      .incoming
99      .as_ref()
100      .ok_or_else::<Error, _>(|| "Send channel closed".into())?;
101    incoming.send(msg)?;
102    Ok(())
103  }
104
105  /// Merge a list of [PortChannel]s into a TransportStream.
106  pub fn merge_all(buffer: &mut [&mut PortChannel]) -> PacketStream {
107    let mut channels = StreamMap::new();
108    for channel in buffer {
109      channels.insert(channel.name.clone(), channel.open());
110    }
111    let stream = channels.map(|(_, packet)| packet);
112
113    PacketStream::new(Box::new(stream))
114  }
115}
116
#[cfg(test)]
mod tests {
  use wasmflow_packet::v1::Packet;
  use wasmflow_transport::{TransportStream, TransportWrapper};

  use super::*;

  /// Declares a test-only [Writable] named `$sender` that emits `$payload` values.
  ///
  /// Each generated sender writes through a single [PortChannel] and reports transaction ID 0.
  macro_rules! test_sender {
    ($sender:ident, $payload:ty) => {
      struct $sender {
        port: PortChannel,
      }

      impl Writable for $sender {
        type PayloadType = $payload;

        fn get_port(&self) -> std::result::Result<&PortChannel, Error> {
          Ok(&self.port)
        }

        fn get_port_name(&self) -> &str {
          self.port.name.as_str()
        }

        fn get_id(&self) -> u32 {
          0
        }
      }
    };
  }

  test_sender!(StringSender, String);
  test_sender!(I64Sender, i64);

  /// Create a [StringSender] on port "test1" and open it.
  ///
  /// Returns the sender together with the stream that receives everything it writes.
  fn opened_string_sender() -> (StringSender, UnboundedReceiverStream<PacketWrapper>) {
    let mut sender = StringSender {
      port: PortChannel::new("test1"),
    };
    let receiver = sender.port.open();
    (sender, receiver)
  }

  #[test_log::test(tokio::test)]
  async fn test_merge() -> Result {
    // Build both ports inside a scope and write one value to each. The senders drop at the end
    // of the scope, which closes their channels.
    let merged = {
      let mut strings = StringSender {
        port: PortChannel::new("test1"),
      };
      let mut numbers = I64Sender {
        port: PortChannel::new("test2"),
      };

      let merged = PortChannel::merge_all(&mut [&mut strings.port, &mut numbers.port]);

      strings.done("First".to_owned())?;
      numbers.done(1i64)?;

      merged
    };
    let mut transport = TransportStream::new(merged.map(|wrapper| wrapper.into()));

    let mut first_batch = transport.drain_port("test1").await?;
    assert_eq!(first_batch.len(), 1);
    assert_eq!(transport.buffered_size(), (1, 1));
    let text: String = first_batch.remove(0).deserialize().unwrap();
    println!("Payload a1: {}", text);
    assert_eq!(text, "First");

    let mut second_batch = transport.drain_port("test2").await?;
    assert_eq!(second_batch.len(), 1);
    assert_eq!(transport.buffered_size(), (0, 0));
    let number: i64 = second_batch.remove(0).deserialize().unwrap();
    println!("Payload b1: {}", number);
    assert_eq!(number, 1);

    Ok(())
  }

  #[test_log::test(tokio::test)]
  async fn test_send() -> Result {
    let (sender, mut receiver) = opened_string_sender();

    sender.done("first".to_owned())?;

    let wrapper: TransportWrapper = receiver.next().await.unwrap().into();
    let text: String = wrapper.payload.deserialize().unwrap();
    assert_eq!(text, "first");

    Ok(())
  }

  #[test_log::test(tokio::test)]
  async fn test_done() -> Result {
    let (sender, mut receiver) = opened_string_sender();

    sender.done("done".to_owned())?;

    let wrapper: TransportWrapper = receiver.next().await.unwrap().into();
    let text: String = wrapper.payload.deserialize().unwrap();
    assert_eq!(text, "done");

    let terminator = receiver.next().await.unwrap();
    assert_eq!(terminator.payload, Packet::done().into());
    Ok(())
  }

  #[test_log::test(tokio::test)]
  async fn test_exception() -> Result {
    let (sender, mut receiver) = opened_string_sender();

    sender.done_exception("exc".to_owned())?;

    let first = receiver.next().await.unwrap();
    assert_eq!(first.payload, Packet::exception("exc").into());

    Ok(())
  }

  #[test_log::test(tokio::test)]
  async fn test_done_exception() -> Result {
    let (sender, mut receiver) = opened_string_sender();

    sender.done_exception("exc".to_owned())?;

    let first = receiver.next().await.unwrap();
    assert_eq!(first.payload, Packet::exception("exc").into());
    let terminator = receiver.next().await.unwrap();
    assert_eq!(terminator.payload, Packet::done().into());
    Ok(())
  }
}
263    Ok(())
264  }
265}