1use std::io;
8
9use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
10use tokio::sync::mpsc;
11use tokio::task::JoinHandle;
12
13use vox_types::{Backing, Link, LinkRx, LinkTx, LinkTxPermit, WriteSlot};
14
15#[cfg(not(target_arch = "wasm32"))]
16use vox_core::{Attachment, LinkSource};
17
/// A link built from an independent read half and write half.
///
/// The halves are stored untouched until the link is split.
pub struct StreamLink<R, W> {
    /// Source of incoming bytes.
    reader: R,
    /// Sink for outgoing bytes.
    writer: W,
}

impl<R, W> StreamLink<R, W> {
    /// Wraps an existing reader/writer pair without performing any I/O.
    pub fn new(reader: R, writer: W) -> Self {
        StreamLink { reader, writer }
    }
}
36
37impl StreamLink<tokio::net::tcp::OwnedReadHalf, tokio::net::tcp::OwnedWriteHalf> {
38 pub fn tcp(stream: tokio::net::TcpStream) -> Self {
40 let (r, w) = stream.into_split();
41 Self::new(r, w)
42 }
43}
44
/// Connection recipe for outgoing TCP links: a target address plus the
/// `TCP_NODELAY` setting to apply after connecting.
#[cfg(not(target_arch = "wasm32"))]
pub struct TcpConnector {
    /// Address handed to the TCP connect call.
    addr: String,
    /// Value passed to `set_nodelay` on each new stream.
    nodelay: bool,
}

/// Shorthand for [`TcpConnector::new`].
#[cfg(not(target_arch = "wasm32"))]
pub fn tcp_connector(addr: impl Into<String>) -> TcpConnector {
    TcpConnector::new(addr)
}

#[cfg(not(target_arch = "wasm32"))]
impl TcpConnector {
    /// Creates a connector for `addr` with `nodelay` enabled by default.
    pub fn new(addr: impl Into<String>) -> Self {
        let addr: String = addr.into();
        Self {
            addr,
            nodelay: true,
        }
    }

    /// Builder-style setter that overrides the `nodelay` flag.
    pub fn nodelay(self, nodelay: bool) -> Self {
        Self { nodelay, ..self }
    }
}
70
71#[cfg(not(target_arch = "wasm32"))]
72impl LinkSource for TcpConnector {
73 type Link = StreamLink<tokio::net::tcp::OwnedReadHalf, tokio::net::tcp::OwnedWriteHalf>;
74
75 async fn next_link(&mut self) -> io::Result<Attachment<Self::Link>> {
76 let stream = tokio::net::TcpStream::connect(&self.addr).await?;
77 stream.set_nodelay(self.nodelay)?;
78 Ok(Attachment::initiator(StreamLink::tcp(stream)))
79 }
80}
81
82impl StreamLink<tokio::io::Stdin, tokio::io::Stdout> {
83 pub fn stdio() -> Self {
85 Self::new(tokio::io::stdin(), tokio::io::stdout())
86 }
87}
88
89#[cfg(unix)]
90impl StreamLink<tokio::net::unix::OwnedReadHalf, tokio::net::unix::OwnedWriteHalf> {
91 pub fn unix(stream: tokio::net::UnixStream) -> Self {
93 let (r, w) = stream.into_split();
94 Self::new(r, w)
95 }
96}
97
#[cfg(windows)]
impl
    StreamLink<
        tokio::io::ReadHalf<tokio::net::windows::named_pipe::NamedPipeClient>,
        tokio::io::WriteHalf<tokio::net::windows::named_pipe::NamedPipeClient>,
    >
{
    /// Builds a link from a Windows named-pipe client, using `tokio::io::split`
    /// to obtain separate read and write halves.
    pub fn named_pipe_client(pipe: tokio::net::windows::named_pipe::NamedPipeClient) -> Self {
        let (read_half, write_half) = tokio::io::split(pipe);
        Self {
            reader: read_half,
            writer: write_half,
        }
    }
}
111
112impl<R, W> Link for StreamLink<R, W>
113where
114 R: AsyncRead + Send + Unpin + 'static,
115 W: AsyncWrite + Send + Unpin + 'static,
116{
117 type Tx = StreamLinkTx;
118 type Rx = StreamLinkRx<BufReader<R>>;
119
120 fn split(self) -> (Self::Tx, Self::Rx) {
121 let (tx_chan, mut rx_chan) = mpsc::channel::<Vec<u8>>(1);
122 let (buf_return_tx, buf_return_rx) = mpsc::unbounded_channel::<Vec<u8>>();
126 let mut writer = BufWriter::new(self.writer);
127
128 let writer_task = tokio::spawn(async move {
129 while let Some(mut bytes) = rx_chan.recv().await {
130 writer
131 .write_all(&(bytes.len() as u32).to_le_bytes())
132 .await?;
133 writer.write_all(&bytes).await?;
134 bytes.clear();
136 let _ = buf_return_tx.send(bytes);
137 while let Ok(mut bytes) = rx_chan.try_recv() {
140 writer
141 .write_all(&(bytes.len() as u32).to_le_bytes())
142 .await?;
143 writer.write_all(&bytes).await?;
144 bytes.clear();
145 let _ = buf_return_tx.send(bytes);
146 }
147 writer.flush().await?;
148 }
149 writer.shutdown().await?;
150 Ok(())
151 });
152
153 (
154 StreamLinkTx {
155 tx: tx_chan,
156 buf_pool: std::sync::Mutex::new(buf_return_rx),
157 writer_task,
158 },
159 StreamLinkRx {
160 reader: BufReader::new(self.reader),
161 },
162 )
163 }
164}
165
166pub struct StreamLinkTx {
176 tx: mpsc::Sender<Vec<u8>>,
177 buf_pool: std::sync::Mutex<mpsc::UnboundedReceiver<Vec<u8>>>,
178 writer_task: JoinHandle<io::Result<()>>,
179}
180
181pub struct StreamLinkTxPermit {
183 permit: mpsc::OwnedPermit<Vec<u8>>,
184 recycled_buf: Option<Vec<u8>>,
185}
186
187pub struct StreamWriteSlot {
189 buf: Vec<u8>,
190 permit: mpsc::OwnedPermit<Vec<u8>>,
191}
192
193impl LinkTx for StreamLinkTx {
194 type Permit = StreamLinkTxPermit;
195
196 async fn reserve(&self) -> io::Result<Self::Permit> {
197 let permit = self.tx.clone().reserve_owned().await.map_err(|_| {
198 io::Error::new(io::ErrorKind::ConnectionReset, "stream writer task stopped")
199 })?;
200 let recycled_buf = self.buf_pool.lock().unwrap().try_recv().ok();
202 Ok(StreamLinkTxPermit {
203 permit,
204 recycled_buf,
205 })
206 }
207
208 async fn close(self) -> io::Result<()> {
209 drop(self.tx);
210 self.writer_task.await.map_err(io::Error::other)?
211 }
212}
213
214impl LinkTxPermit for StreamLinkTxPermit {
216 type Slot = StreamWriteSlot;
217
218 fn alloc(self, len: usize) -> io::Result<Self::Slot> {
219 let mut buf = self.recycled_buf.unwrap_or_default();
220 buf.resize(len, 0);
221 Ok(StreamWriteSlot {
222 buf,
223 permit: self.permit,
224 })
225 }
226}
227
228impl WriteSlot for StreamWriteSlot {
229 fn as_mut_slice(&mut self) -> &mut [u8] {
230 &mut self.buf
231 }
232
233 fn commit(self) {
234 drop(self.permit.send(self.buf));
235 }
236}
237
/// Receiving half of a stream link; decodes frames from the wrapped reader.
pub struct StreamLinkRx<R> {
    /// Byte source the frames are read from.
    reader: R,
}
246
247impl<R: AsyncRead + Send + Unpin + 'static> LinkRx for StreamLinkRx<R> {
249 type Error = io::Error;
250
251 async fn recv(&mut self) -> io::Result<Option<Backing>> {
252 let mut len_buf = [0u8; 4];
253 match self.reader.read_exact(&mut len_buf).await {
254 Ok(_) => {}
255 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
256 Err(e) => return Err(e),
257 }
258 let len = u32::from_le_bytes(len_buf) as usize;
259 let mut buf = vec![0u8; len];
260 self.reader.read_exact(&mut buf).await?;
261 Ok(Some(Backing::Boxed(buf.into_boxed_slice())))
262 }
263}
264
265type BoxReader = Box<dyn AsyncRead + Send + Unpin>;
270type BoxWriter = Box<dyn AsyncWrite + Send + Unpin>;
271
272pub struct LocalLink {
279 inner: StreamLink<BoxReader, BoxWriter>,
280}
281
282impl LocalLink {
283 #[cfg(unix)]
285 pub async fn connect(addr: &str) -> io::Result<Self> {
286 let stream = tokio::net::UnixStream::connect(addr).await?;
287 let (r, w) = stream.into_split();
288 Ok(Self {
289 inner: StreamLink::new(Box::new(r), Box::new(w)),
290 })
291 }
292
293 #[cfg(windows)]
295 pub async fn connect(addr: &str) -> io::Result<Self> {
296 let pipe = tokio::net::windows::named_pipe::ClientOptions::new().open(addr)?;
297 let (r, w) = tokio::io::split(pipe);
298 Ok(Self {
299 inner: StreamLink::new(Box::new(r), Box::new(w)),
300 })
301 }
302}
303
304impl Link for LocalLink {
305 type Tx = StreamLinkTx;
306 type Rx = StreamLinkRx<BufReader<BoxReader>>;
307
308 fn split(self) -> (Self::Tx, Self::Rx) {
309 self.inner.split()
310 }
311}
312
313pub struct LocalLinkAcceptor {
320 #[cfg(unix)]
321 listener: tokio::net::UnixListener,
322 #[cfg(windows)]
326 addr: String,
327 #[cfg(windows)]
328 pending: moire::sync::Mutex<tokio::net::windows::named_pipe::NamedPipeServer>,
329}
330
331impl LocalLinkAcceptor {
332 #[cfg(unix)]
334 pub fn bind(addr: impl Into<String>) -> io::Result<Self> {
335 let listener = tokio::net::UnixListener::bind(addr.into())?;
336 Ok(Self { listener })
337 }
338
339 #[cfg(windows)]
341 pub fn bind(addr: impl Into<String>) -> io::Result<Self> {
342 use tokio::net::windows::named_pipe::ServerOptions;
343 let addr = addr.into();
344 let server = ServerOptions::new()
345 .first_pipe_instance(true)
346 .create(&addr)?;
347 Ok(Self {
348 addr,
349 pending: moire::sync::Mutex::new("local-link-acceptor.pending", server),
350 })
351 }
352
353 #[cfg(unix)]
355 pub async fn accept(&self) -> io::Result<LocalLink> {
356 let (stream, _addr) = self.listener.accept().await?;
357 let (r, w) = stream.into_split();
358 Ok(LocalLink {
359 inner: StreamLink::new(Box::new(r), Box::new(w)),
360 })
361 }
362
363 #[cfg(windows)]
365 pub async fn accept(&self) -> io::Result<LocalLink> {
366 use tokio::net::windows::named_pipe::ServerOptions;
367 let mut guard = self.pending.lock().await;
368 guard.connect().await?;
369 let next = ServerOptions::new().create(&self.addr)?;
370 let connected = std::mem::replace(&mut *guard, next);
371 drop(guard);
372 let (r, w) = tokio::io::split(connected);
373 Ok(LocalLink {
374 inner: StreamLink::new(Box::new(r), Box::new(w)),
375 })
376 }
377}
378
#[cfg(test)]
mod tests {
    use tokio::io::split;
    use vox_types::{Backing, Link, LinkRx, LinkTx, LinkTxPermit, WriteSlot};

    use super::*;

    type DuplexRead = tokio::io::ReadHalf<tokio::io::DuplexStream>;
    type DuplexWrite = tokio::io::WriteHalf<tokio::io::DuplexStream>;
    type DuplexLink = StreamLink<DuplexRead, DuplexWrite>;

    /// Two links joined by an in-memory duplex pipe.
    fn duplex_pair() -> (DuplexLink, DuplexLink) {
        let (left, right) = tokio::io::duplex(4096);
        let (left_r, left_w) = split(left);
        let (right_r, right_w) = split(right);
        (
            StreamLink::new(left_r, left_w),
            StreamLink::new(right_r, right_w),
        )
    }

    /// Borrows the bytes of a received frame regardless of its backing.
    fn payload(link: &Backing) -> &[u8] {
        match link {
            Backing::Boxed(b) => b,
            Backing::Shared(s) => s.as_bytes(),
        }
    }

    /// Reserves, fills and commits one frame containing `bytes`.
    async fn send(tx: &StreamLinkTx, bytes: &[u8]) {
        let permit = tx.reserve().await.unwrap();
        let mut slot = permit.alloc(bytes.len()).unwrap();
        slot.as_mut_slice().copy_from_slice(bytes);
        slot.commit();
    }

    #[tokio::test]
    async fn round_trip_single() {
        let (a, b) = duplex_pair();
        let (tx_a, _rx_a) = a.split();
        let (_tx_b, mut rx_b) = b.split();

        send(&tx_a, b"hello").await;

        let msg = rx_b.recv().await.unwrap().unwrap();
        assert_eq!(payload(&msg), b"hello");
    }

    #[tokio::test]
    async fn multiple_messages_in_order() {
        let (a, b) = duplex_pair();
        let (tx_a, _rx_a) = a.split();
        let (_tx_b, mut rx_b) = b.split();

        let payloads: &[&[u8]] = &[b"one", b"two", b"three", b"four"];
        for &p in payloads {
            send(&tx_a, p).await;
        }

        for expected in payloads {
            let msg = rx_b.recv().await.unwrap().unwrap();
            assert_eq!(payload(&msg), *expected);
        }
    }

    /// A zero-length frame is still delivered as a frame.
    #[tokio::test]
    async fn empty_payload() {
        let (a, b) = duplex_pair();
        let (tx_a, _rx_a) = a.split();
        let (_tx_b, mut rx_b) = b.split();

        let permit = tx_a.reserve().await.unwrap();
        let slot = permit.alloc(0).unwrap();
        slot.commit();

        let msg = rx_b.recv().await.unwrap().unwrap();
        assert_eq!(payload(&msg), b"");
    }

    /// Closing the sender surfaces as `None` on the peer, repeatedly.
    #[tokio::test]
    async fn eof_on_peer_close() {
        let (a, b) = duplex_pair();
        let (tx_a, _rx_a) = a.split();
        let (_tx_b, mut rx_b) = b.split();

        tx_a.close().await.unwrap();

        for _ in 0..2 {
            assert!(rx_b.recv().await.unwrap().is_none());
        }
    }

    /// A permit dropped before `alloc` must not produce a frame.
    #[tokio::test]
    async fn dropped_permit_sends_nothing() {
        let (a, b) = duplex_pair();
        let (tx_a, _rx_a) = a.split();
        let (_tx_b, mut rx_b) = b.split();

        let unused = tx_a.reserve().await.unwrap();
        drop(unused);

        send(&tx_a, b"yep").await;

        let msg = rx_b.recv().await.unwrap().unwrap();
        assert_eq!(payload(&msg), b"yep");
    }

    /// A slot dropped before `commit` must not produce a frame.
    #[tokio::test]
    async fn dropped_slot_sends_nothing() {
        let (a, b) = duplex_pair();
        let (tx_a, _rx_a) = a.split();
        let (_tx_b, mut rx_b) = b.split();

        let permit = tx_a.reserve().await.unwrap();
        let abandoned = permit.alloc(3).unwrap();
        drop(abandoned);

        send(&tx_a, b"ok").await;

        let msg = rx_b.recv().await.unwrap().unwrap();
        assert_eq!(payload(&msg), b"ok");
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn local_link_round_trip() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.sock");
        let addr = path.to_str().unwrap();

        let acceptor = LocalLinkAcceptor::bind(addr).unwrap();

        let connect_addr = addr.to_string();
        let server = tokio::spawn(async move {
            let link = acceptor.accept().await.unwrap();
            let (_tx, mut rx) = link.split();
            rx.recv().await.unwrap().unwrap()
        });

        let client_link = LocalLink::connect(&connect_addr).await.unwrap();
        let (tx, _rx) = client_link.split();
        send(&tx, b"local").await;

        let msg = server.await.unwrap();
        assert_eq!(payload(&msg), b"local");
    }
}