// tokio_shared_udp_socket/lib.rs

1extern crate futures;
2extern crate future_utils;
3extern crate tokio;
4#[macro_use]
5extern crate unwrap;
6extern crate bytes;
7extern crate void;
8#[cfg(test)]
9#[macro_use]
10extern crate net_literals;
11#[macro_use]
12extern crate log;
13
14use std::{mem, io};
15use std::collections::{hash_map, HashMap};
16use std::sync::{Arc, Mutex};
17use std::net::SocketAddr;
18use bytes::{BytesMut, Bytes};
19use futures::{Async, AsyncSink, Stream, Sink};
20use future_utils::mpsc::{self, UnboundedReceiver, UnboundedSender};
21use tokio::net::UdpSocket;
22use void::{ResultVoidExt};
23
/// A UDP socket that can easily be shared amongst a bunch of different futures.
///
/// Created with `SharedUdpSocket::share`, which also returns the `IncomingEndpoints` stream.
/// Per-address `UdpEndpoint`s are handed out by the `endpoint` and `try_endpoint` methods.
/// Dropping this handle clears the sender that feeds `IncomingEndpoints`.
pub struct SharedUdpSocket {
    // State shared with the `IncomingEndpoints` stream and every `UdpEndpoint`.
    inner: Arc<SharedUdpSocketInner>,
}
28
/// A `Stream` of `UdpEndpoint`s, one for each remote address that sends a packet while no
/// endpoint is registered for it.
///
/// Returned together with the `SharedUdpSocket` by `SharedUdpSocket::share`. Every poll first
/// reads the packets that are ready on the socket and routes them to the registered endpoints.
pub struct IncomingEndpoints {
    // State shared with the `SharedUdpSocket` handle and every `UdpEndpoint`.
    inner: Arc<SharedUdpSocketInner>,
    // Receives the endpoints that `pump` creates for previously unseen addresses.
    incoming_rx: UnboundedReceiver<UdpEndpoint>,
    // Scratch receive buffer handed to `pump` on every poll.
    buffer: BytesMut,
}
34
/// A `Sink`/`Stream` that can be used to send/receive packets to/from a particular address.
///
/// These can be created by calling `SharedUdpSocket::endpoint` or
/// `SharedUdpSocket::try_endpoint`. The `IncomingEndpoints` stream also yields these when a
/// packet arrives from an address that has no endpoint registered.
///
/// Dropping a `UdpEndpoint` removes the registration for its address.
pub struct UdpEndpoint {
    // State shared with the `SharedUdpSocket` handle and all other endpoints.
    inner: Arc<SharedUdpSocketInner>,
    // Packets that `pump` routed to this endpoint's address.
    incoming_rx: UnboundedReceiver<Bytes>,
    // Remote address this endpoint sends to and receives from.
    addr: SocketAddr,
    // Scratch receive buffer handed to `pump` on every poll.
    buffer: BytesMut,
}
45
/// State shared between the `SharedUdpSocket`, its `IncomingEndpoints` stream and all
/// `UdpEndpoint`s.
///
/// `pump` acquires the locks in the order `socket`, `endpoints`, `incoming_tx`.
struct SharedUdpSocketInner {
    // The underlying socket; `None` once it has been taken by one of the `steal` methods.
    socket: Mutex<Option<UdpSocket>>,
    // For every registered remote address, the sender feeding that endpoint's packet queue.
    endpoints: Mutex<HashMap<SocketAddr, UnboundedSender<Bytes>>>,
    // Sender feeding `IncomingEndpoints`; set to `None` when the `SharedUdpSocket` is dropped.
    incoming_tx: Mutex<Option<UnboundedSender<UdpEndpoint>>>,
}
51
52impl SharedUdpSocket {
53    /// Create a new `SharedUdpSocket` from a `UdpSocket`.
54    pub fn share(socket: UdpSocket) -> (SharedUdpSocket, IncomingEndpoints) {
55        trace!("creating shared udp socket on address {:?}", socket.local_addr());
56        let (tx, rx) = mpsc::unbounded();
57        let inner = SharedUdpSocketInner {
58            socket: Mutex::new(Some(socket)),
59            endpoints: Mutex::new(HashMap::new()),
60            incoming_tx: Mutex::new(Some(tx)),
61        };
62        let inner = Arc::new(inner);
63        let shared = SharedUdpSocket {
64            inner: inner.clone(),
65        };
66        let incoming = IncomingEndpoints {
67            inner,
68            incoming_rx: rx,
69            buffer: BytesMut::new(),
70        };
71        (shared, incoming)
72    }
73
74    /// Creates a `UdpEndpoint` object which receives all packets that arrive from the given
75    /// address. `UdpEndpoint` can also be used as a `Sink` to send packets. If another
76    /// `UdpEndpoint` with the given address already exists then it will no longer receive packets
77    /// since the newly created `UdpEndpoint` will take precedence.
78    pub fn endpoint(&self, addr: SocketAddr) -> UdpEndpoint {
79        let (tx, endpoint) = endpoint_new(&self.inner, addr);
80        let mut endpoints = unwrap!(self.inner.endpoints.lock());
81        let _ = endpoints.insert(addr, tx);
82        endpoint
83    }
84
85    /// Creates a `UdpEndpoint` object which receives all packets that arrive from the given
86    /// address. `UdpEndpoint` can also be used as a `Sink` to send packets. Unlike the `endpoint`
87    /// method, this method will not replace any pre-existing `UdpEndpoint` associated with the
88    /// given address and will instead return `None` if one exists.
89    pub fn try_endpoint(&self, addr: SocketAddr) -> Option<UdpEndpoint> {
90        let mut endpoints = unwrap!(self.inner.endpoints.lock());
91        match endpoints.entry(addr) {
92            hash_map::Entry::Occupied(..) => None,
93            hash_map::Entry::Vacant(ve) => {
94                let (tx, endpoint) = endpoint_new(&self.inner, addr);
95                let _ = ve.insert(tx);
96                Some(endpoint)
97            },
98        }
99    }
100
101    /// Steals the udp socket (if it hasn't already been stolen) causing all other
102    /// `SharedUdpSocket` and `UdpEndpoint` streams to end.
103    pub fn steal(self) -> Option<UdpSocket> {
104        let mut socket_opt = unwrap!(self.inner.socket.lock());
105        socket_opt.take()
106    }
107
108    pub fn local_addr(&self) -> io::Result<SocketAddr> {
109        self.inner.local_addr()
110    }
111}
112
113fn pump(inner: &Arc<SharedUdpSocketInner>, buffer: &mut BytesMut) -> io::Result<()> {
114    let mut socket_opt = unwrap!(inner.socket.lock());
115    let socket = match *socket_opt {
116        Some(ref mut socket) => socket,
117        None => return Ok(()),
118    };
119
120    loop {
121        let min_capacity = 64 * 1024 + 1;
122        let capacity = buffer.capacity();
123        if capacity < min_capacity {
124            buffer.reserve(min_capacity - capacity);
125        }
126        let capacity = buffer.capacity();
127        unsafe {
128            buffer.set_len(capacity)
129        }
130        match socket.poll_recv_from(&mut *buffer) {
131            Ok(Async::Ready((n, addr))) => {
132                if n == buffer.len() {
133                    return Err(io::Error::new(
134                        io::ErrorKind::Other,
135                        "failed to recv entire dgram",
136                    ));
137                }
138                let data = buffer.split_to(n).freeze();
139                let mut endpoints = unwrap!(inner.endpoints.lock());
140                let drop_after_unlock = match endpoints.entry(addr) {
141                    hash_map::Entry::Occupied(mut oe) => {
142                        match oe.get().unbounded_send(data) {
143                            Ok(()) => None,
144                            Err(send_error) => {
145                                if let Some(ref incoming_tx) = *unwrap!(inner.incoming_tx.lock()) {
146                                    let (tx, endpoint) = endpoint_new(inner, addr);
147
148                                    unwrap!(tx.unbounded_send(send_error.into_inner()));
149                                    let _ = mem::replace(oe.get_mut(), tx);
150                                    match incoming_tx.unbounded_send(endpoint) {
151                                        Ok(()) => None,
152                                        Err(send_error) => Some(send_error.into_inner()),
153                                    }
154                                } else {
155                                    None
156                                }
157                            },
158                        }
159                    },
160                    hash_map::Entry::Vacant(ve) => {
161                        if let Some(ref incoming_tx) = *unwrap!(inner.incoming_tx.lock()) {
162                            let (tx, endpoint) = endpoint_new(inner, addr);
163
164                            unwrap!(tx.unbounded_send(data));
165                            ve.insert(tx);
166                            match incoming_tx.unbounded_send(endpoint) {
167                                Ok(()) => None,
168                                Err(send_error) => Some(send_error.into_inner()),
169                            }
170                        } else {
171                            None
172                        }
173                    },
174                };
175                drop(endpoints);
176                drop(drop_after_unlock);
177            },
178            Ok(Async::NotReady) => return Ok(()),
179            Err(e) => {
180                match e.kind() {
181                    io::ErrorKind::WouldBlock => return Ok(()),
182                    io::ErrorKind::ConnectionReset => continue,
183                    _ => return Err(e),
184                }
185            },
186        }
187    }
188}
189
190fn endpoint_new(inner: &Arc<SharedUdpSocketInner>, addr: SocketAddr) -> (UnboundedSender<Bytes>, UdpEndpoint) {
191    let (tx, rx) = mpsc::unbounded();
192    let inner = inner.clone();
193    let endpoint = UdpEndpoint {
194        inner: inner,
195        incoming_rx: rx,
196        addr: addr,
197        buffer: BytesMut::new(),
198    };
199    (tx, endpoint)
200}
201
202impl UdpEndpoint {
203    /// Get the remote address that this `UdpEndpoint` sends/receives packets to/from.
204    pub fn remote_addr(&self) -> SocketAddr {
205        self.addr
206    }
207
208    /// Steals the udp socket (if it hasn't already been stolen) causing all other
209    /// `SharedUdpSocket` and `UdpEndpoint` streams to end.
210    pub fn steal(self) -> Option<UdpSocket> {
211        let mut socket_opt = unwrap!(self.inner.socket.lock());
212        socket_opt.take()
213    }
214
215    pub fn local_addr(&self) -> io::Result<SocketAddr> {
216        self.inner.local_addr()
217    }
218}
219
220impl SharedUdpSocketInner {
221    pub fn local_addr(&self) -> io::Result<SocketAddr> {
222        let socket_opt = unwrap!(self.socket.lock());
223        match *socket_opt {
224            Some(ref socket) => socket.local_addr(),
225            None => Err(io::Error::new(io::ErrorKind::Other, "socket has been stolen")),
226        }
227    }
228}
229
230impl Stream for IncomingEndpoints {
231    type Item = UdpEndpoint;
232    type Error = io::Error;
233
234    fn poll(&mut self) -> io::Result<Async<Option<UdpEndpoint>>> {
235        pump(&self.inner, &mut self.buffer)?;
236
237        Ok(self.incoming_rx.poll().void_unwrap())
238    }
239}
240
241impl Stream for UdpEndpoint {
242    type Item = Bytes;
243    type Error = io::Error;
244
245    fn poll(&mut self) -> io::Result<Async<Option<Bytes>>> {
246        pump(&self.inner, &mut self.buffer)?;
247
248        Ok(self.incoming_rx.poll().void_unwrap())
249    }
250}
251
252impl Sink for UdpEndpoint {
253    type SinkItem = Bytes;
254    type SinkError = io::Error;
255
256    fn start_send(&mut self, item: Bytes) -> io::Result<AsyncSink<Bytes>> {
257        let mut socket_opt = unwrap!(self.inner.socket.lock());
258        let socket = match *socket_opt {
259            Some(ref mut socket) => socket,
260            None => return Err(io::ErrorKind::NotConnected.into()),
261        };
262
263        match socket.poll_send_to(&item, &self.addr) {
264            Ok(Async::Ready(n)) => {
265                if n != item.len() {
266                    return Err(io::Error::new(
267                        io::ErrorKind::Other,
268                        "failed to send entire dgram",
269                    ));
270                }
271                return Ok(AsyncSink::Ready);
272            },
273            Ok(Async::NotReady) => return Ok(AsyncSink::NotReady(item)),
274            Err(e) => {
275                if e.kind() == io::ErrorKind::WouldBlock {
276                    return Ok(AsyncSink::NotReady(item));
277                }
278                return Err(e);
279            },
280        }
281    }
282
283    fn poll_complete(&mut self) -> io::Result<Async<()>> {
284        Ok(Async::Ready(()))
285    }
286}
287
288impl Drop for SharedUdpSocket {
289    fn drop(&mut self) {
290        let mut incoming_tx = unwrap!(self.inner.incoming_tx.lock());
291        *incoming_tx = None;
292    }
293}
294
295impl Drop for UdpEndpoint {
296    fn drop(&mut self) {
297        let mut endpoints = unwrap!(self.inner.endpoints.lock());
298        let _ = endpoints.remove(&self.addr);
299    }
300}
301
#[cfg(test)]
mod test {
    use super::*;
    use futures::Future;

    /// End-to-end check over loopback: two plain sockets talk to one shared socket, and the
    /// packets must surface on the right endpoints. The steps are chained with `and_then`, so
    /// they run strictly in sequence.
    #[test]
    fn test() {
        // Two independent peers.
        let sock0 = unwrap!(UdpSocket::bind(&addr!("127.0.0.1:0")));
        let addr0 = unwrap!(sock0.local_addr());
        let sock1 = unwrap!(UdpSocket::bind(&addr!("127.0.0.1:0")));
        let addr1 = unwrap!(sock1.local_addr());

        // The socket under test. `_shared` is kept alive until the end of the function.
        let shared = unwrap!(UdpSocket::bind(&addr!("127.0.0.1:0")));
        let shared_addr = unwrap!(shared.local_addr());
        let (_shared, incoming) = SharedUdpSocket::share(shared);

        tokio::run({
            // Step 1: sock0 sends its first packet to the shared socket.
            sock0
            .send_dgram(b"qqqq", &shared_addr)
            .map_err(|e| panic!("{}", e))
            .and_then(move |(sock0, _)| {
                // Step 2: that packet must show up as a new endpoint for sock0's address.
                // From here on `shared` names the rest of the `IncomingEndpoints` stream.
                incoming
                .into_future()
                .map_err(|(e, _)| panic!("{}", e))
                .and_then(move |(opt, shared)| {
                    let endpoint_0 = unwrap!(opt);
                    assert_eq!(endpoint_0.remote_addr(), addr0);

                    // Step 3: the new endpoint already holds the packet that created it.
                    endpoint_0
                    .into_future()
                    .map_err(|(e, _)| panic!("{}", e))
                    .and_then(move |(opt, endpoint_0)| {
                        let data = unwrap!(opt);
                        assert_eq!(&data[..], b"qqqq");

                        // Step 4: a second packet from sock0, then a first one from sock1.
                        sock0
                        .send_dgram(b"wwww", &shared_addr)
                        .map_err(|e| panic!("{}", e))
                        .and_then(move |(sock0, _)| {
                            sock1
                            .send_dgram(b"eeee", &shared_addr)
                            .map_err(|e| panic!("{}", e))
                            .and_then(move |_sock1| {
                                // Step 5: the next new endpoint must belong to sock1, not to
                                // the already registered sock0.
                                shared
                                .into_future()
                                .map_err(|(e, _)| panic!("{}", e))
                                .and_then(move |(opt, shared)| {
                                    let endpoint_1 = unwrap!(opt);
                                    assert_eq!(endpoint_1.remote_addr(), addr1);
                                    // The endpoints must keep working without the incoming stream.
                                    drop(shared);

                                    // Step 6: each endpoint yields only its own peer's packet.
                                    endpoint_1
                                    .into_future()
                                    .map_err(|(e, _)| panic!("{}", e))
                                    .and_then(move |(opt, _endpoint_1)| {
                                        let data = unwrap!(opt);
                                        assert_eq!(&data[..], b"eeee");

                                        endpoint_0
                                        .into_future()
                                        .map_err(|(e, _)| panic!("{}", e))
                                        .and_then(move |(opt, endpoint_0)| {
                                            let data = unwrap!(opt);
                                            assert_eq!(&data[..], b"wwww");

                                            // Step 7: send through the endpoint's `Sink` side and
                                            // check that sock0 receives it from the shared address.
                                            endpoint_0
                                            .send(Bytes::from(&b"rrrr"[..]))
                                            .map_err(|e| panic!("{}", e))
                                            .and_then(move |endpoint_0| {
                                                let buff = [0; 10];

                                                sock0
                                                .recv_dgram(buff)
                                                .map_err(|e| panic!("{}", e))
                                                .map(move |(_sock0, data, len, addr)| {
                                                    assert_eq!(addr, shared_addr);
                                                    assert_eq!(&data[..len], b"rrrr");
                                                    // Step 8: the socket has not been stolen yet.
                                                    assert!(endpoint_0.steal().is_some());
                                                })
                                            })
                                        })
                                    })
                                })
                            })
                        })
                    })
                })
            })
        });
    }
}
393