wasmflow_traits/
writable_port.rs1use serde::Serialize;
2use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
3use tokio_stream::wrappers::UnboundedReceiverStream;
4use tokio_stream::{StreamExt, StreamMap};
5use wasmflow_packet::v1::Packet as V1;
6use wasmflow_packet::{Packet, PacketWrapper};
7use wasmflow_streams::PacketStream;
8
/// Boxed, thread-safe error used by every fallible operation in this module.
type Error = Box<dyn std::error::Error + Sync + Send + 'static>;

/// Unit-valued result returned by port operations; failures carry an [`Error`].
type Result = std::result::Result<(), Error>;
12
13fn send_message(port: &PortChannel, name: impl AsRef<str>, packet: Packet) -> Result {
14 port.send(PacketWrapper {
15 payload: packet,
16 port: name.as_ref().to_owned(),
17 })
18}
19
20pub trait Writable {
22 type PayloadType: Serialize;
24
25 fn get_port(&self) -> std::result::Result<&PortChannel, Error>;
27
28 fn get_port_name(&self) -> &str;
30
31 fn get_id(&self) -> u32;
33
34 fn done(&self, data: Self::PayloadType) -> Result {
36 let port = self.get_port()?;
37 let name = self.get_port_name();
38 send_message(port, name, Packet::V1(V1::success(&data)))?;
39 send_message(port, name, Packet::V1(V1::done()))
40 }
41
42 fn done_message(&self, packet: Packet) -> Result {
44 let port = self.get_port()?;
45 let name = self.get_port_name();
46 send_message(port, name, packet)?;
47 send_message(port, name, Packet::V1(V1::done()))
48 }
49
50 fn done_exception(&self, payload: String) -> Result {
52 let port = self.get_port()?;
53 let name = self.get_port_name();
54 send_message(port, name, V1::exception(payload).into())?;
55 send_message(port, name, Packet::V1(V1::done()))
56 }
57}
58
59#[must_use]
61#[derive(Debug, Clone)]
62pub struct PortChannel {
63 pub name: String,
65 incoming: Option<UnboundedSender<PacketWrapper>>,
66}
67
68impl PortChannel {
69 pub fn new<T: AsRef<str>>(name: T) -> Self {
71 Self {
72 name: name.as_ref().to_owned(),
73 incoming: None,
74 }
75 }
76
77 pub fn open(&mut self) -> UnboundedReceiverStream<PacketWrapper> {
79 let (tx, rx) = unbounded_channel();
80 self.incoming = Some(tx);
81 UnboundedReceiverStream::new(rx)
82 }
83
84 pub fn close(&mut self) {
86 self.incoming.take();
87 }
88
89 #[must_use]
91 pub fn is_closed(&self) -> bool {
92 self.incoming.is_none()
93 }
94
95 pub fn send(&self, msg: PacketWrapper) -> Result {
97 let incoming = self
98 .incoming
99 .as_ref()
100 .ok_or_else::<Error, _>(|| "Send channel closed".into())?;
101 incoming.send(msg)?;
102 Ok(())
103 }
104
105 pub fn merge_all(buffer: &mut [&mut PortChannel]) -> PacketStream {
107 let mut channels = StreamMap::new();
108 for channel in buffer {
109 channels.insert(channel.name.clone(), channel.open());
110 }
111 let stream = channels.map(|(_, packet)| packet);
112
113 PacketStream::new(Box::new(stream))
114 }
115}
116
#[cfg(test)]
mod tests {

    use wasmflow_packet::v1::Packet;
    use wasmflow_transport::{TransportStream, TransportWrapper};

    use super::*;

    /// Minimal `Writable` for the tests: owns a `PortChannel` and serializes `T` payloads.
    struct TestSender<T> {
        port: PortChannel,
        _payload: std::marker::PhantomData<T>,
    }

    impl<T> TestSender<T> {
        fn new(name: &str) -> Self {
            Self {
                port: PortChannel::new(name),
                _payload: std::marker::PhantomData,
            }
        }
    }

    impl<T: Serialize> Writable for TestSender<T> {
        type PayloadType = T;

        fn get_port(&self) -> std::result::Result<&PortChannel, Error> {
            Ok(&self.port)
        }

        fn get_port_name(&self) -> &str {
            &self.port.name
        }

        fn get_id(&self) -> u32 {
            0
        }
    }

    /// Builds a `String` sender named `name` whose channel is already open.
    fn opened_string_sender(
        name: &str,
    ) -> (TestSender<String>, UnboundedReceiverStream<PacketWrapper>) {
        let mut sender = TestSender::new(name);
        let rx = sender.port.open();
        (sender, rx)
    }

    #[test_log::test(tokio::test)]
    async fn test_merge() -> Result {
        // Both senders live only inside this scope and are dropped before draining.
        // Presumably this lets the merged stream terminate; confirm against
        // `TransportStream::drain_port`.
        let merged = {
            let mut strings = TestSender::<String>::new("test1");
            let mut numbers = TestSender::<i64>::new("test2");

            let merged = PortChannel::merge_all(&mut [&mut strings.port, &mut numbers.port]);

            strings.done("First".to_owned())?;
            numbers.done(1i64)?;

            merged
        };
        let mut transport = TransportStream::new(merged.map(|wrapper| wrapper.into()));

        let mut first = transport.drain_port("test1").await?;
        assert_eq!(first.len(), 1);
        assert_eq!(transport.buffered_size(), (1, 1));
        let text: String = first.remove(0).deserialize().unwrap();
        println!("Payload a1: {}", text);
        assert_eq!(text, "First");

        let mut second = transport.drain_port("test2").await?;
        assert_eq!(second.len(), 1);
        assert_eq!(transport.buffered_size(), (0, 0));
        let number: i64 = second.remove(0).deserialize().unwrap();
        println!("Payload b1: {}", number);
        assert_eq!(number, 1);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn test_send() -> Result {
        let (sender, mut rx) = opened_string_sender("test1");

        sender.done("first".to_owned())?;

        let wrapper: TransportWrapper = rx.next().await.unwrap().into();
        let text: String = wrapper.payload.deserialize().unwrap();
        assert_eq!(text, "first");

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn test_done() -> Result {
        let (sender, mut rx) = opened_string_sender("test1");

        sender.done("done".to_owned())?;

        let wrapper: TransportWrapper = rx.next().await.unwrap().into();
        let text: String = wrapper.payload.deserialize().unwrap();
        assert_eq!(text, "done");

        let trailer = rx.next().await.unwrap();
        assert_eq!(trailer.payload, Packet::done().into());
        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn test_exception() -> Result {
        let (sender, mut rx) = opened_string_sender("test1");

        sender.done_exception("exc".to_owned())?;

        let first = rx.next().await.unwrap();
        assert_eq!(first.payload, Packet::exception("exc").into());

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn test_done_exception() -> Result {
        let (sender, mut rx) = opened_string_sender("test1");

        sender.done_exception("exc".to_owned())?;

        let first = rx.next().await.unwrap();
        assert_eq!(first.payload, Packet::exception("exc").into());

        let trailer = rx.next().await.unwrap();
        assert_eq!(trailer.payload, Packet::done().into());
        Ok(())
    }
}