trillium_server_common/
binding.rs

1use crate::Transport;
2use futures_lite::{AsyncRead, AsyncWrite, Stream};
3use std::{
4    io::{IoSlice, Result},
5    pin::Pin,
6    task::{Context, Poll},
7};
8
/// An either-style enum that holds a tcp-flavored value or a
/// unix-flavored value. The wrapped value can be a listener (like
/// TcpListener), a Stream (like Incoming), or a bytestream (like
/// TcpStream). Trait implementations such as TryFrom, Stream,
/// AsyncRead, and AsyncWrite are provided for this enum and delegate
/// to the wrapped value.
#[derive(Debug, Clone)]
pub enum Binding<T, U> {
    /// a tcp type (listener or incoming or stream)
    Tcp(T),

    /// a unix type (listener or incoming or stream)
    Unix(U),
}

use Binding::{Tcp, Unix};

impl<T, U> Binding<T, U> {
    /// returns a shared reference to the tcp value, or `None` when
    /// this is the unix variant
    pub fn get_tcp(&self) -> Option<&T> {
        match self {
            Tcp(tcp) => Some(tcp),
            Unix(_) => None,
        }
    }

    /// returns a shared reference to the unix value, or `None` when
    /// this is the tcp variant
    pub fn get_unix(&self) -> Option<&U> {
        match self {
            Unix(unix) => Some(unix),
            Tcp(_) => None,
        }
    }

    /// returns an exclusive reference to the tcp value, or `None`
    /// when this is the unix variant
    pub fn get_tcp_mut(&mut self) -> Option<&mut T> {
        match self {
            Tcp(tcp) => Some(tcp),
            Unix(_) => None,
        }
    }

    /// returns an exclusive reference to the unix value, or `None`
    /// when this is the tcp variant
    pub fn get_unix_mut(&mut self) -> Option<&mut U> {
        match self {
            Unix(unix) => Some(unix),
            Tcp(_) => None,
        }
    }
}
61
62impl<T: TryFrom<std::net::TcpListener>, U> TryFrom<std::net::TcpListener> for Binding<T, U> {
63    type Error = <T as TryFrom<std::net::TcpListener>>::Error;
64
65    fn try_from(value: std::net::TcpListener) -> std::result::Result<Self, Self::Error> {
66        Ok(Self::Tcp(value.try_into()?))
67    }
68}
69
70#[cfg(unix)]
71impl<T, U: TryFrom<std::os::unix::net::UnixListener>> TryFrom<std::os::unix::net::UnixListener>
72    for Binding<T, U>
73{
74    type Error = <U as TryFrom<std::os::unix::net::UnixListener>>::Error;
75
76    fn try_from(value: std::os::unix::net::UnixListener) -> std::result::Result<Self, Self::Error> {
77        Ok(Self::Unix(value.try_into()?))
78    }
79}
80
81impl<T, U, TI, UI> Stream for Binding<T, U>
82where
83    T: Stream<Item = Result<TI>> + Unpin,
84    U: Stream<Item = Result<UI>> + Unpin,
85{
86    type Item = Result<Binding<TI, UI>>;
87
88    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
89        match &mut *self {
90            Tcp(t) => Pin::new(t).poll_next(cx).map(|i| i.map(|x| x.map(Tcp))),
91            Unix(u) => Pin::new(u).poll_next(cx).map(|i| i.map(|x| x.map(Unix))),
92        }
93    }
94}
95
96impl<T, U> Binding<T, U>
97where
98    T: AsyncRead + Unpin,
99    U: AsyncRead + Unpin,
100{
101    fn as_async_read(&mut self) -> Pin<&mut (dyn AsyncRead + Unpin)> {
102        Pin::new(match self {
103            Tcp(t) => t as &mut (dyn AsyncRead + Unpin),
104            Unix(u) => u as &mut (dyn AsyncRead + Unpin),
105        })
106    }
107}
108
109impl<T, U> Binding<T, U>
110where
111    T: AsyncWrite + Unpin,
112    U: AsyncWrite + Unpin,
113{
114    fn as_async_write(&mut self) -> Pin<&mut (dyn AsyncWrite + Unpin)> {
115        Pin::new(match self {
116            Tcp(t) => t as &mut (dyn AsyncWrite + Unpin),
117            Unix(u) => u as &mut (dyn AsyncWrite + Unpin),
118        })
119    }
120}
121
122impl<T, U> AsyncRead for Binding<T, U>
123where
124    T: AsyncRead + Unpin,
125    U: AsyncRead + Unpin,
126{
127    fn poll_read(
128        mut self: Pin<&mut Self>,
129        cx: &mut Context<'_>,
130        buf: &mut [u8],
131    ) -> Poll<Result<usize>> {
132        self.as_async_read().poll_read(cx, buf)
133    }
134
135    fn poll_read_vectored(
136        mut self: Pin<&mut Self>,
137        cx: &mut Context<'_>,
138        bufs: &mut [std::io::IoSliceMut<'_>],
139    ) -> Poll<Result<usize>> {
140        self.as_async_read().poll_read_vectored(cx, bufs)
141    }
142}
143
144impl<T, U> AsyncWrite for Binding<T, U>
145where
146    T: AsyncWrite + Unpin,
147    U: AsyncWrite + Unpin,
148{
149    fn poll_write(
150        mut self: Pin<&mut Self>,
151        cx: &mut Context<'_>,
152        buf: &[u8],
153    ) -> Poll<Result<usize>> {
154        self.as_async_write().poll_write(cx, buf)
155    }
156
157    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
158        self.as_async_write().poll_flush(cx)
159    }
160
161    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
162        self.as_async_write().poll_close(cx)
163    }
164
165    fn poll_write_vectored(
166        mut self: Pin<&mut Self>,
167        cx: &mut Context<'_>,
168        bufs: &[IoSlice<'_>],
169    ) -> Poll<Result<usize>> {
170        self.as_async_write().poll_write_vectored(cx, bufs)
171    }
172}
173
174impl<T, U> Binding<T, U>
175where
176    T: Transport,
177    U: Transport,
178{
179    fn as_transport_mut(&mut self) -> &mut dyn Transport {
180        match self {
181            Tcp(t) => t as &mut dyn Transport,
182            Unix(u) => u as &mut dyn Transport,
183        }
184    }
185
186    fn as_transport(&self) -> &dyn Transport {
187        match self {
188            Tcp(t) => t as &dyn Transport,
189            Unix(u) => u as &dyn Transport,
190        }
191    }
192}
193
194impl<T, U> Transport for Binding<T, U>
195where
196    T: Transport,
197    U: Transport,
198{
199    fn set_linger(&mut self, linger: Option<std::time::Duration>) -> Result<()> {
200        self.as_transport_mut().set_linger(linger)
201    }
202
203    fn set_nodelay(&mut self, nodelay: bool) -> Result<()> {
204        self.as_transport_mut().set_nodelay(nodelay)
205    }
206
207    fn set_ip_ttl(&mut self, ttl: u32) -> Result<()> {
208        self.as_transport_mut().set_ip_ttl(ttl)
209    }
210
211    fn peer_addr(&self) -> Result<Option<std::net::SocketAddr>> {
212        self.as_transport().peer_addr()
213    }
214}