1extern crate futures;
2extern crate future_utils;
3extern crate tokio;
4#[macro_use]
5extern crate unwrap;
6extern crate bytes;
7extern crate void;
8#[cfg(test)]
9#[macro_use]
10extern crate net_literals;
11#[macro_use]
12extern crate log;
13
14use std::{mem, io};
15use std::collections::{hash_map, HashMap};
16use std::sync::{Arc, Mutex};
17use std::net::SocketAddr;
18use bytes::{BytesMut, Bytes};
19use futures::{Async, AsyncSink, Stream, Sink};
20use future_utils::mpsc::{self, UnboundedReceiver, UnboundedSender};
21use tokio::net::UdpSocket;
22use void::{ResultVoidExt};
23
/// Handle to a UDP socket that is shared between several [`UdpEndpoint`]s, one per
/// remote address.
///
/// Created by [`SharedUdpSocket::share`], which also returns the paired
/// [`IncomingEndpoints`] stream. While this handle is alive, a datagram from an address
/// with no registered endpoint produces a new endpoint on that stream. When this handle
/// is dropped, the sender feeding the stream is dropped too; from then on, such
/// datagrams are discarded.
pub struct SharedUdpSocket {
    inner: Arc<SharedUdpSocketInner>,
}
28
/// Stream of [`UdpEndpoint`]s for remote addresses that sent a datagram to the shared
/// socket while no endpoint was registered for them.
///
/// Each yielded endpoint already has its triggering datagram queued. Polling this stream
/// also drains the shared socket on behalf of every registered endpoint.
pub struct IncomingEndpoints {
    inner: Arc<SharedUdpSocketInner>,
    // Receives the endpoints that the receive loop creates for unknown senders.
    incoming_rx: UnboundedReceiver<UdpEndpoint>,
    // Scratch receive buffer lent to the receive loop while this stream is polled.
    buffer: BytesMut,
}
34
/// One remote peer's view of the shared socket.
///
/// As a `Stream` it yields the datagrams received from `addr`. As a `Sink` it sends
/// datagrams to `addr` through the shared socket. Dropping it removes the registration
/// for `addr`.
pub struct UdpEndpoint {
    inner: Arc<SharedUdpSocketInner>,
    // Datagrams routed to this endpoint by the receive loop.
    incoming_rx: UnboundedReceiver<Bytes>,
    // Remote address this endpoint talks to.
    addr: SocketAddr,
    // Scratch receive buffer lent to the receive loop while this endpoint is polled.
    buffer: BytesMut,
}
45
/// State shared (via `Arc`) by [`SharedUdpSocket`], [`IncomingEndpoints`] and every
/// [`UdpEndpoint`] created from them.
///
/// The receive loop (`pump`) locks `socket`, then `endpoints`, then `incoming_tx`, in that
/// order. Other code must not acquire them in a conflicting order.
struct SharedUdpSocketInner {
    /// The underlying socket; `None` once it has been taken with `steal`.
    socket: Mutex<Option<UdpSocket>>,
    /// Registered endpoints keyed by remote address; each sender feeds that endpoint's
    /// datagram stream.
    endpoints: Mutex<HashMap<SocketAddr, UnboundedSender<Bytes>>>,
    /// Sender half of the `IncomingEndpoints` channel; reset to `None` when the
    /// `SharedUdpSocket` handle is dropped.
    incoming_tx: Mutex<Option<UnboundedSender<UdpEndpoint>>>,
}
51
52impl SharedUdpSocket {
53 pub fn share(socket: UdpSocket) -> (SharedUdpSocket, IncomingEndpoints) {
55 trace!("creating shared udp socket on address {:?}", socket.local_addr());
56 let (tx, rx) = mpsc::unbounded();
57 let inner = SharedUdpSocketInner {
58 socket: Mutex::new(Some(socket)),
59 endpoints: Mutex::new(HashMap::new()),
60 incoming_tx: Mutex::new(Some(tx)),
61 };
62 let inner = Arc::new(inner);
63 let shared = SharedUdpSocket {
64 inner: inner.clone(),
65 };
66 let incoming = IncomingEndpoints {
67 inner,
68 incoming_rx: rx,
69 buffer: BytesMut::new(),
70 };
71 (shared, incoming)
72 }
73
74 pub fn endpoint(&self, addr: SocketAddr) -> UdpEndpoint {
79 let (tx, endpoint) = endpoint_new(&self.inner, addr);
80 let mut endpoints = unwrap!(self.inner.endpoints.lock());
81 let _ = endpoints.insert(addr, tx);
82 endpoint
83 }
84
85 pub fn try_endpoint(&self, addr: SocketAddr) -> Option<UdpEndpoint> {
90 let mut endpoints = unwrap!(self.inner.endpoints.lock());
91 match endpoints.entry(addr) {
92 hash_map::Entry::Occupied(..) => None,
93 hash_map::Entry::Vacant(ve) => {
94 let (tx, endpoint) = endpoint_new(&self.inner, addr);
95 let _ = ve.insert(tx);
96 Some(endpoint)
97 },
98 }
99 }
100
101 pub fn steal(self) -> Option<UdpSocket> {
104 let mut socket_opt = unwrap!(self.inner.socket.lock());
105 socket_opt.take()
106 }
107
108 pub fn local_addr(&self) -> io::Result<SocketAddr> {
109 self.inner.local_addr()
110 }
111}
112
113fn pump(inner: &Arc<SharedUdpSocketInner>, buffer: &mut BytesMut) -> io::Result<()> {
114 let mut socket_opt = unwrap!(inner.socket.lock());
115 let socket = match *socket_opt {
116 Some(ref mut socket) => socket,
117 None => return Ok(()),
118 };
119
120 loop {
121 let min_capacity = 64 * 1024 + 1;
122 let capacity = buffer.capacity();
123 if capacity < min_capacity {
124 buffer.reserve(min_capacity - capacity);
125 }
126 let capacity = buffer.capacity();
127 unsafe {
128 buffer.set_len(capacity)
129 }
130 match socket.poll_recv_from(&mut *buffer) {
131 Ok(Async::Ready((n, addr))) => {
132 if n == buffer.len() {
133 return Err(io::Error::new(
134 io::ErrorKind::Other,
135 "failed to recv entire dgram",
136 ));
137 }
138 let data = buffer.split_to(n).freeze();
139 let mut endpoints = unwrap!(inner.endpoints.lock());
140 let drop_after_unlock = match endpoints.entry(addr) {
141 hash_map::Entry::Occupied(mut oe) => {
142 match oe.get().unbounded_send(data) {
143 Ok(()) => None,
144 Err(send_error) => {
145 if let Some(ref incoming_tx) = *unwrap!(inner.incoming_tx.lock()) {
146 let (tx, endpoint) = endpoint_new(inner, addr);
147
148 unwrap!(tx.unbounded_send(send_error.into_inner()));
149 let _ = mem::replace(oe.get_mut(), tx);
150 match incoming_tx.unbounded_send(endpoint) {
151 Ok(()) => None,
152 Err(send_error) => Some(send_error.into_inner()),
153 }
154 } else {
155 None
156 }
157 },
158 }
159 },
160 hash_map::Entry::Vacant(ve) => {
161 if let Some(ref incoming_tx) = *unwrap!(inner.incoming_tx.lock()) {
162 let (tx, endpoint) = endpoint_new(inner, addr);
163
164 unwrap!(tx.unbounded_send(data));
165 ve.insert(tx);
166 match incoming_tx.unbounded_send(endpoint) {
167 Ok(()) => None,
168 Err(send_error) => Some(send_error.into_inner()),
169 }
170 } else {
171 None
172 }
173 },
174 };
175 drop(endpoints);
176 drop(drop_after_unlock);
177 },
178 Ok(Async::NotReady) => return Ok(()),
179 Err(e) => {
180 match e.kind() {
181 io::ErrorKind::WouldBlock => return Ok(()),
182 io::ErrorKind::ConnectionReset => continue,
183 _ => return Err(e),
184 }
185 },
186 }
187 }
188}
189
190fn endpoint_new(inner: &Arc<SharedUdpSocketInner>, addr: SocketAddr) -> (UnboundedSender<Bytes>, UdpEndpoint) {
191 let (tx, rx) = mpsc::unbounded();
192 let inner = inner.clone();
193 let endpoint = UdpEndpoint {
194 inner: inner,
195 incoming_rx: rx,
196 addr: addr,
197 buffer: BytesMut::new(),
198 };
199 (tx, endpoint)
200}
201
202impl UdpEndpoint {
203 pub fn remote_addr(&self) -> SocketAddr {
205 self.addr
206 }
207
208 pub fn steal(self) -> Option<UdpSocket> {
211 let mut socket_opt = unwrap!(self.inner.socket.lock());
212 socket_opt.take()
213 }
214
215 pub fn local_addr(&self) -> io::Result<SocketAddr> {
216 self.inner.local_addr()
217 }
218}
219
220impl SharedUdpSocketInner {
221 pub fn local_addr(&self) -> io::Result<SocketAddr> {
222 let socket_opt = unwrap!(self.socket.lock());
223 match *socket_opt {
224 Some(ref socket) => socket.local_addr(),
225 None => Err(io::Error::new(io::ErrorKind::Other, "socket has been stolen")),
226 }
227 }
228}
229
230impl Stream for IncomingEndpoints {
231 type Item = UdpEndpoint;
232 type Error = io::Error;
233
234 fn poll(&mut self) -> io::Result<Async<Option<UdpEndpoint>>> {
235 pump(&self.inner, &mut self.buffer)?;
236
237 Ok(self.incoming_rx.poll().void_unwrap())
238 }
239}
240
241impl Stream for UdpEndpoint {
242 type Item = Bytes;
243 type Error = io::Error;
244
245 fn poll(&mut self) -> io::Result<Async<Option<Bytes>>> {
246 pump(&self.inner, &mut self.buffer)?;
247
248 Ok(self.incoming_rx.poll().void_unwrap())
249 }
250}
251
252impl Sink for UdpEndpoint {
253 type SinkItem = Bytes;
254 type SinkError = io::Error;
255
256 fn start_send(&mut self, item: Bytes) -> io::Result<AsyncSink<Bytes>> {
257 let mut socket_opt = unwrap!(self.inner.socket.lock());
258 let socket = match *socket_opt {
259 Some(ref mut socket) => socket,
260 None => return Err(io::ErrorKind::NotConnected.into()),
261 };
262
263 match socket.poll_send_to(&item, &self.addr) {
264 Ok(Async::Ready(n)) => {
265 if n != item.len() {
266 return Err(io::Error::new(
267 io::ErrorKind::Other,
268 "failed to send entire dgram",
269 ));
270 }
271 return Ok(AsyncSink::Ready);
272 },
273 Ok(Async::NotReady) => return Ok(AsyncSink::NotReady(item)),
274 Err(e) => {
275 if e.kind() == io::ErrorKind::WouldBlock {
276 return Ok(AsyncSink::NotReady(item));
277 }
278 return Err(e);
279 },
280 }
281 }
282
283 fn poll_complete(&mut self) -> io::Result<Async<()>> {
284 Ok(Async::Ready(()))
285 }
286}
287
288impl Drop for SharedUdpSocket {
289 fn drop(&mut self) {
290 let mut incoming_tx = unwrap!(self.inner.incoming_tx.lock());
291 *incoming_tx = None;
292 }
293}
294
295impl Drop for UdpEndpoint {
296 fn drop(&mut self) {
297 let mut endpoints = unwrap!(self.inner.endpoints.lock());
298 let _ = endpoints.remove(&self.addr);
299 }
300}
301
#[cfg(test)]
mod test {
    use super::*;
    use futures::Future;

    /// End-to-end check of the shared socket, using two plain peer sockets.
    ///
    /// - The first datagram from each peer yields a new endpoint on `IncomingEndpoints`.
    /// - Later datagrams reach the existing endpoint.
    /// - An endpoint keeps receiving after the `IncomingEndpoints` stream is dropped.
    /// - Sending through an endpoint reaches the peer.
    /// - `steal` returns the socket.
    #[test]
    fn test() {
        // Two independent peers, plus the socket that will be shared.
        let sock0 = unwrap!(UdpSocket::bind(&addr!("127.0.0.1:0")));
        let addr0 = unwrap!(sock0.local_addr());
        let sock1 = unwrap!(UdpSocket::bind(&addr!("127.0.0.1:0")));
        let addr1 = unwrap!(sock1.local_addr());

        let shared = unwrap!(UdpSocket::bind(&addr!("127.0.0.1:0")));
        let shared_addr = unwrap!(shared.local_addr());
        // `_shared` keeps the `SharedUdpSocket` handle alive for the whole test, so
        // unknown senders keep producing new endpoints.
        let (_shared, incoming) = SharedUdpSocket::share(shared);

        tokio::run({
            // Peer 0's first datagram should create endpoint_0.
            sock0
            .send_dgram(b"qqqq", &shared_addr)
            .map_err(|e| panic!("{}", e))
            .and_then(move |(sock0, _)| {
                incoming
                .into_future()
                .map_err(|(e, _)| panic!("{}", e))
                // From here on, `shared` is the remainder of the `IncomingEndpoints`
                // stream, not the `SharedUdpSocket`.
                .and_then(move |(opt, shared)| {
                    let endpoint_0 = unwrap!(opt);
                    assert_eq!(endpoint_0.remote_addr(), addr0);

                    // The triggering datagram is already queued on the new endpoint.
                    endpoint_0
                    .into_future()
                    .map_err(|(e, _)| panic!("{}", e))
                    .and_then(move |(opt, endpoint_0)| {
                        let data = unwrap!(opt);
                        assert_eq!(&data[..], b"qqqq");

                        // A second datagram from the now-known peer 0, then a first one
                        // from peer 1.
                        sock0
                        .send_dgram(b"wwww", &shared_addr)
                        .map_err(|e| panic!("{}", e))
                        .and_then(move |(sock0, _)| {
                            sock1
                            .send_dgram(b"eeee", &shared_addr)
                            .map_err(|e| panic!("{}", e))
                            .and_then(move |_sock1| {
                                // Peer 1 is new, so it should appear as a new endpoint.
                                shared
                                .into_future()
                                .map_err(|(e, _)| panic!("{}", e))
                                .and_then(move |(opt, shared)| {
                                    let endpoint_1 = unwrap!(opt);
                                    assert_eq!(endpoint_1.remote_addr(), addr1);
                                    // Drop the IncomingEndpoints stream; the endpoints
                                    // must keep working without it.
                                    drop(shared);

                                    endpoint_1
                                    .into_future()
                                    .map_err(|(e, _)| panic!("{}", e))
                                    .and_then(move |(opt, _endpoint_1)| {
                                        let data = unwrap!(opt);
                                        assert_eq!(&data[..], b"eeee");

                                        // Peer 0's second datagram was routed to the
                                        // existing endpoint_0.
                                        endpoint_0
                                        .into_future()
                                        .map_err(|(e, _)| panic!("{}", e))
                                        .and_then(move |(opt, endpoint_0)| {
                                            let data = unwrap!(opt);
                                            assert_eq!(&data[..], b"wwww");

                                            // Outbound: the Sink side should deliver to peer 0.
                                            endpoint_0
                                            .send(Bytes::from(&b"rrrr"[..]))
                                            .map_err(|e| panic!("{}", e))
                                            .and_then(move |endpoint_0| {
                                                let buff = [0; 10];

                                                sock0
                                                .recv_dgram(buff)
                                                .map_err(|e| panic!("{}", e))
                                                .map(move |(_sock0, data, len, addr)| {
                                                    assert_eq!(addr, shared_addr);
                                                    assert_eq!(&data[..len], b"rrrr");
                                                    // The socket has not been stolen yet,
                                                    // so steal hands it back.
                                                    assert!(endpoint_0.steal().is_some());
                                                })
                                            })
                                        })
                                    })
                                })
                            })
                        })
                    })
                })
            })
        });
    }
}
393