xenet_socket/
lib.rs

1#[cfg(not(target_os = "windows"))]
2mod unix;
3#[cfg(not(target_os = "windows"))]
4pub use unix::*;
5
6#[cfg(target_os = "windows")]
7mod windows;
8#[cfg(target_os = "windows")]
9pub use windows::*;
10
11use async_io::Async;
12use socket2::{Domain, SockAddr, Socket as SystemSocket, Type};
13use std::io;
14use std::mem::MaybeUninit;
15use std::net::{Shutdown, SocketAddr};
16use std::sync::Arc;
17use std::time::Duration;
18
19use xenet_packet::ip::IpNextLevelProtocol;
20
/// IP version: IPv4 or IPv6.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum IpVersion {
    /// Internet Protocol version 4.
    V4,
    /// Internet Protocol version 6.
    V6,
}
27
28impl IpVersion {
29    /// IP Version number as u8.
30    pub fn version_u8(&self) -> u8 {
31        match self {
32            IpVersion::V4 => 4,
33            IpVersion::V6 => 6,
34        }
35    }
36    /// Return true if IP version is IPv4.
37    pub fn is_ipv4(&self) -> bool {
38        match self {
39            IpVersion::V4 => true,
40            IpVersion::V6 => false,
41        }
42    }
43    /// Return true if IP version is IPv6.
44    pub fn is_ipv6(&self) -> bool {
45        match self {
46            IpVersion::V4 => false,
47            IpVersion::V6 => true,
48        }
49    }
50    pub(crate) fn to_domain(&self) -> Domain {
51        match self {
52            IpVersion::V4 => Domain::IPV4,
53            IpVersion::V6 => Domain::IPV6,
54        }
55    }
56}
57
/// Socket type.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SocketType {
    /// Raw socket (`SOCK_RAW`).
    Raw,
    /// Datagram socket (`SOCK_DGRAM`). Usually used for UDP.
    Datagram,
    /// Stream socket (`SOCK_STREAM`). Used for TCP.
    Stream,
}
68
69impl SocketType {
70    pub(crate) fn to_type(&self) -> Type {
71        match self {
72            SocketType::Raw => Type::RAW,
73            SocketType::Datagram => Type::DGRAM,
74            SocketType::Stream => Type::STREAM,
75        }
76    }
77}
78
79/// Socket option.
80#[derive(Clone, Debug)]
81pub struct SocketOption {
82    /// IP version
83    pub ip_version: IpVersion,
84    /// Socket type
85    pub socket_type: SocketType,
86    /// Protocol. TCP, UDP, ICMP, etc.
87    pub protocol: Option<IpNextLevelProtocol>,
88    /// Timeout
89    pub timeout: Option<u64>,
90    /// TTL or Hop Limit
91    pub ttl: Option<u32>,
92    /// Non-blocking mode
93    pub non_blocking: bool,
94}
95
96impl SocketOption {
97    /// Constructs a new SocketOption.
98    pub fn new(
99        ip_version: IpVersion,
100        socket_type: SocketType,
101        protocol: Option<IpNextLevelProtocol>,
102    ) -> SocketOption {
103        SocketOption {
104            ip_version,
105            socket_type,
106            protocol,
107            timeout: None,
108            ttl: None,
109            non_blocking: false,
110        }
111    }
112    /// Check socket option.
113    /// Return Ok(()) if socket option is valid.
114    pub fn is_valid(&self) -> Result<(), String> {
115        check_socket_option(self.clone())
116    }
117}
118
119/// Async socket. Provides cross-platform async adapter for system’s socket.
120#[derive(Clone, Debug)]
121pub struct AsyncSocket {
122    inner: Arc<Async<SystemSocket>>,
123}
124
125impl AsyncSocket {
126    /// Constructs a new AsyncSocket.
127    pub fn new(socket_option: SocketOption) -> io::Result<AsyncSocket> {
128        let socket: SystemSocket = if let Some(protocol) = socket_option.protocol {
129            SystemSocket::new(
130                socket_option.ip_version.to_domain(),
131                socket_option.socket_type.to_type(),
132                Some(to_socket_protocol(protocol)),
133            )?
134        } else {
135            SystemSocket::new(
136                socket_option.ip_version.to_domain(),
137                socket_option.socket_type.to_type(),
138                None,
139            )?
140        };
141        socket.set_nonblocking(true)?;
142        Ok(AsyncSocket {
143            inner: Arc::new(Async::new(socket)?),
144        })
145    }
146    /// Send packet.
147    pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
148        loop {
149            self.inner.writable().await?;
150            match self.inner.write_with(|inner| inner.send(buf)).await {
151                Ok(n) => return Ok(n),
152                Err(_) => continue,
153            }
154        }
155    }
156    /// Send packet to target.
157    pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
158        let target: SockAddr = SockAddr::from(target);
159        loop {
160            self.inner.writable().await?;
161            match self
162                .inner
163                .write_with(|inner| inner.send_to(buf, &target))
164                .await
165            {
166                Ok(n) => return Ok(n),
167                Err(_) => continue,
168            }
169        }
170    }
171    /// Receive packet.
172    pub async fn receive(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
173        let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
174        loop {
175            self.inner.readable().await?;
176            match self.inner.read_with(|inner| inner.recv(recv_buf)).await {
177                Ok(result) => return Ok(result),
178                Err(_) => continue,
179            }
180        }
181    }
182    /// Receive packet with sender address.
183    pub async fn receive_from(&self, buf: &mut Vec<u8>) -> io::Result<(usize, SocketAddr)> {
184        let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
185        loop {
186            self.inner.readable().await?;
187            match self
188                .inner
189                .read_with(|inner| inner.recv_from(recv_buf))
190                .await
191            {
192                Ok(result) => {
193                    let (n, addr) = result;
194                    match addr.as_socket() {
195                        Some(addr) => return Ok((n, addr)),
196                        None => continue,
197                    }
198                }
199                Err(_) => continue,
200            }
201        }
202    }
203    /// Write data to the socket and send to the target.
204    /// Return how many bytes were written.
205    pub async fn write(&self, buf: &[u8]) -> io::Result<usize> {
206        loop {
207            self.inner.writable().await?;
208            match self.inner.write_with(|inner| inner.send(buf)).await {
209                Ok(n) => return Ok(n),
210                Err(_) => continue,
211            }
212        }
213    }
214    /// Read data from the socket.
215    /// Return how many bytes were read.
216    pub async fn read(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
217        let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
218        loop {
219            self.inner.readable().await?;
220            match self.inner.read_with(|inner| inner.recv(recv_buf)).await {
221                Ok(result) => return Ok(result),
222                Err(_) => continue,
223            }
224        }
225    }
226    /// Bind socket to address.
227    pub async fn bind(&self, addr: SocketAddr) -> io::Result<()> {
228        let addr: SockAddr = SockAddr::from(addr);
229        self.inner.writable().await?;
230        self.inner.write_with(|inner| inner.bind(&addr)).await
231    }
232    /// Set receive timeout.
233    pub async fn set_receive_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
234        self.inner.writable().await?;
235        self.inner
236            .write_with(|inner| inner.set_read_timeout(timeout))
237            .await
238    }
239    /// Set TTL or Hop Limit.
240    pub async fn set_ttl(&self, ttl: u32, ip_version: IpVersion) -> io::Result<()> {
241        self.inner.writable().await?;
242        match ip_version {
243            IpVersion::V4 => self.inner.write_with(|inner| inner.set_ttl(ttl)).await,
244            IpVersion::V6 => {
245                self.inner
246                    .write_with(|inner| inner.set_unicast_hops_v6(ttl))
247                    .await
248            }
249        }
250    }
251    /// Initiate TCP connection.
252    pub async fn connect(&self, addr: SocketAddr) -> io::Result<()> {
253        let addr: SockAddr = SockAddr::from(addr);
254        self.inner.writable().await?;
255        self.inner.write_with(|inner| inner.connect(&addr)).await
256    }
257    /// Shutdown TCP connection.
258    pub async fn shutdown(&self, how: Shutdown) -> io::Result<()> {
259        self.inner.writable().await?;
260        self.inner.write_with(|inner| inner.shutdown(how)).await
261    }
262    /// Listen TCP connection.
263    pub async fn listen(&self, backlog: i32) -> io::Result<()> {
264        self.inner.writable().await?;
265        self.inner.write_with(|inner| inner.listen(backlog)).await
266    }
267    /// Accept TCP connection.
268    pub async fn accept(&self) -> io::Result<(AsyncSocket, SocketAddr)> {
269        self.inner.readable().await?;
270        match self.inner.read_with(|inner| inner.accept()).await {
271            Ok((socket, addr)) => {
272                let socket = AsyncSocket {
273                    inner: Arc::new(Async::new(socket)?),
274                };
275                Ok((socket, addr.as_socket().unwrap()))
276            }
277            Err(e) => Err(e),
278        }
279    }
280    /// Get peer address.
281    pub async fn peer_addr(&self) -> io::Result<SocketAddr> {
282        self.inner.writable().await?;
283        match self.inner.read_with(|inner| inner.peer_addr()).await {
284            Ok(addr) => Ok(addr.as_socket().unwrap()),
285            Err(e) => Err(e),
286        }
287    }
288    /// Get local address.
289    pub async fn local_addr(&self) -> io::Result<SocketAddr> {
290        self.inner.writable().await?;
291        match self.inner.read_with(|inner| inner.local_addr()).await {
292            Ok(addr) => Ok(addr.as_socket().unwrap()),
293            Err(e) => Err(e),
294        }
295    }
296    /// Initiate a connection on this socket to the specified address, only only waiting for a certain period of time for the connection to be established.
297    /// The non-blocking state of the socket is overridden by this function.
298    pub async fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
299        let addr: SockAddr = SockAddr::from(*addr);
300        self.inner.writable().await?;
301        self.inner
302            .write_with(|inner| inner.connect_timeout(&addr, timeout))
303            .await
304    }
305    /// Set the value of the `SO_BROADCAST` option for this socket.
306    ///
307    /// When enabled, this socket is allowed to send packets to a broadcast address.
308    pub async fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
309        self.inner.writable().await?;
310        self.inner
311            .write_with(|inner| inner.set_nonblocking(nonblocking))
312            .await
313    }
314    /// Set the value of the `SO_BROADCAST` option for this socket.
315    ///
316    /// When enabled, this socket is allowed to send packets to a broadcast address.
317    pub async fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
318        self.inner.writable().await?;
319        self.inner
320            .write_with(|inner| inner.set_broadcast(broadcast))
321            .await
322    }
323    /// Get the value of the `SO_ERROR` option on this socket.
324    pub async fn get_error(&self) -> io::Result<Option<io::Error>> {
325        self.inner.readable().await?;
326        self.inner.read_with(|inner| inner.take_error()).await
327    }
328    /// Set value for the `SO_KEEPALIVE` option on this socket.
329    ///
330    /// Enable sending of keep-alive messages on connection-oriented sockets.
331    pub async fn set_keepalive(&self, keepalive: bool) -> io::Result<()> {
332        self.inner.writable().await?;
333        self.inner
334            .write_with(|inner| inner.set_keepalive(keepalive))
335            .await
336    }
337    /// Set value for the `SO_RCVBUF` option on this socket.
338    ///
339    /// Changes the size of the operating system's receive buffer associated with the socket.
340    pub async fn set_receive_buffer_size(&self, size: usize) -> io::Result<()> {
341        self.inner.writable().await?;
342        self.inner
343            .write_with(|inner| inner.set_recv_buffer_size(size))
344            .await
345    }
346    /// Set value for the `SO_REUSEADDR` option on this socket.
347    ///
348    /// This indicates that futher calls to `bind` may allow reuse of local addresses.
349    pub async fn set_reuse_address(&self, reuse: bool) -> io::Result<()> {
350        self.inner.writable().await?;
351        self.inner
352            .write_with(|inner| inner.set_reuse_address(reuse))
353            .await
354    }
355    /// Set value for the `SO_SNDBUF` option on this socket.
356    ///
357    /// Changes the size of the operating system's send buffer associated with the socket.
358    pub async fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
359        self.inner.writable().await?;
360        self.inner
361            .write_with(|inner| inner.set_send_buffer_size(size))
362            .await
363    }
364    /// Set value for the `SO_SNDTIMEO` option on this socket.
365    ///
366    /// If `timeout` is `None`, then `write` and `send` calls will block indefinitely.
367    pub async fn set_send_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
368        self.inner.writable().await?;
369        self.inner
370            .write_with(|inner| inner.set_write_timeout(duration))
371            .await
372    }
373    /// Set the value of the `TCP_NODELAY` option on this socket.
374    ///
375    /// If set, segments are always sent as soon as possible, even if there is only a small amount of data.
376    pub async fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
377        self.inner.writable().await?;
378        self.inner
379            .write_with(|inner| inner.set_nodelay(nodelay))
380            .await
381    }
382}
383
384/// Socket. Provides cross-platform adapter for system’s socket.
385#[derive(Clone, Debug)]
386pub struct Socket {
387    inner: Arc<SystemSocket>,
388}
389
390impl Socket {
391    /// Constructs a new Socket.
392    pub fn new(socket_option: SocketOption) -> io::Result<Socket> {
393        let socket: SystemSocket = if let Some(protocol) = socket_option.protocol {
394            SystemSocket::new(
395                socket_option.ip_version.to_domain(),
396                socket_option.socket_type.to_type(),
397                Some(to_socket_protocol(protocol)),
398            )?
399        } else {
400            SystemSocket::new(
401                socket_option.ip_version.to_domain(),
402                socket_option.socket_type.to_type(),
403                None,
404            )?
405        };
406        if socket_option.non_blocking {
407            socket.set_nonblocking(true)?;
408        }
409        Ok(Socket {
410            inner: Arc::new(socket),
411        })
412    }
413    /// Send packet to target.
414    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
415        let target: SockAddr = SockAddr::from(target);
416        match self.inner.send_to(buf, &target) {
417            Ok(n) => Ok(n),
418            Err(e) => Err(e),
419        }
420    }
421    /// Receive packet.
422    pub fn receive(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
423        let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
424        match self.inner.recv(recv_buf) {
425            Ok(result) => Ok(result),
426            Err(e) => Err(e),
427        }
428    }
429    /// Receive packet with sender address.
430    pub fn receive_from(&self, buf: &mut Vec<u8>) -> io::Result<(usize, SocketAddr)> {
431        let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
432        match self.inner.recv_from(recv_buf) {
433            Ok(result) => {
434                let (n, addr) = result;
435                match addr.as_socket() {
436                    Some(addr) => return Ok((n, addr)),
437                    None => {
438                        return Err(io::Error::new(
439                            io::ErrorKind::Other,
440                            "Invalid socket address",
441                        ))
442                    }
443                }
444            }
445            Err(e) => Err(e),
446        }
447    }
448    /// Write data to the socket and send to the target.
449    /// Return how many bytes were written.
450    pub fn write(&self, buf: &[u8]) -> io::Result<usize> {
451        match self.inner.send(buf) {
452            Ok(n) => Ok(n),
453            Err(e) => Err(e),
454        }
455    }
456    /// Read data from the socket.
457    /// Return how many bytes were read.
458    pub fn read(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
459        let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
460        match self.inner.recv(recv_buf) {
461            Ok(result) => Ok(result),
462            Err(e) => Err(e),
463        }
464    }
465    /// Bind socket to address.
466    pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
467        let addr: SockAddr = SockAddr::from(addr);
468        self.inner.bind(&addr)
469    }
470    /// Set receive timeout.
471    pub fn set_receive_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
472        self.inner.set_read_timeout(timeout)
473    }
474    /// Set TTL or Hop Limit.
475    pub fn set_ttl(&self, ttl: u32, ip_version: IpVersion) -> io::Result<()> {
476        match ip_version {
477            IpVersion::V4 => self.inner.set_ttl(ttl),
478            IpVersion::V6 => self.inner.set_unicast_hops_v6(ttl),
479        }
480    }
481    /// Initiate TCP connection.
482    pub fn connect(&self, addr: SocketAddr) -> io::Result<()> {
483        let addr: SockAddr = SockAddr::from(addr);
484        self.inner.connect(&addr)
485    }
486    /// Shutdown TCP connection.
487    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
488        self.inner.shutdown(how)
489    }
490    /// Listen TCP connection.
491    pub fn listen(&self, backlog: i32) -> io::Result<()> {
492        self.inner.listen(backlog)
493    }
494    /// Accept TCP connection.
495    pub fn accept(&self) -> io::Result<(Socket, SocketAddr)> {
496        match self.inner.accept() {
497            Ok((socket, addr)) => Ok((
498                Socket {
499                    inner: Arc::new(socket),
500                },
501                addr.as_socket().unwrap(),
502            )),
503            Err(e) => Err(e),
504        }
505    }
506    /// Get peer address.
507    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
508        match self.inner.peer_addr() {
509            Ok(addr) => Ok(addr.as_socket().unwrap()),
510            Err(e) => Err(e),
511        }
512    }
513    /// Get local address.
514    pub fn local_addr(&self) -> io::Result<SocketAddr> {
515        match self.inner.local_addr() {
516            Ok(addr) => Ok(addr.as_socket().unwrap()),
517            Err(e) => Err(e),
518        }
519    }
520    /// Initiate a connection on this socket to the specified address, only only waiting for a certain period of time for the connection to be established.
521    /// The non-blocking state of the socket is overridden by this function.
522    pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
523        let addr: SockAddr = SockAddr::from(*addr);
524        self.inner.connect_timeout(&addr, timeout)
525    }
526    pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
527        self.inner.set_nonblocking(nonblocking)
528    }
529    /// Set the value of the `SO_BROADCAST` option for this socket.
530    ///
531    /// When enabled, this socket is allowed to send packets to a broadcast address.
532    pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
533        self.inner.set_broadcast(broadcast)
534    }
535    /// Get the value of the `SO_ERROR` option on this socket.
536    pub fn get_error(&self) -> io::Result<Option<io::Error>> {
537        self.inner.take_error()
538    }
539    /// Set value for the `SO_KEEPALIVE` option on this socket.
540    ///
541    /// Enable sending of keep-alive messages on connection-oriented sockets.
542    pub fn set_keepalive(&self, keepalive: bool) -> io::Result<()> {
543        self.inner.set_keepalive(keepalive)
544    }
545    /// Set value for the `SO_RCVBUF` option on this socket.
546    ///
547    /// Changes the size of the operating system's receive buffer associated with the socket.
548    pub fn set_receive_buffer_size(&self, size: usize) -> io::Result<()> {
549        self.inner.set_recv_buffer_size(size)
550    }
551    /// Set value for the `SO_REUSEADDR` option on this socket.
552    ///
553    /// This indicates that futher calls to `bind` may allow reuse of local addresses.
554    pub fn set_reuse_address(&self, reuse: bool) -> io::Result<()> {
555        self.inner.set_reuse_address(reuse)
556    }
557    /// Set value for the `SO_SNDBUF` option on this socket.
558    ///
559    /// Changes the size of the operating system's send buffer associated with the socket.
560    pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
561        self.inner.set_send_buffer_size(size)
562    }
563    /// Set value for the `SO_SNDTIMEO` option on this socket.
564    ///
565    /// If `timeout` is `None`, then `write` and `send` calls will block indefinitely.
566    pub fn set_send_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
567        self.inner.set_write_timeout(duration)
568    }
569    /// Set the value of the `TCP_NODELAY` option on this socket.
570    ///
571    /// If set, segments are always sent as soon as possible, even if there is only a small amount of data.
572    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
573        self.inner.set_nodelay(nodelay)
574    }
575}
576
577fn to_socket_protocol(protocol: IpNextLevelProtocol) -> socket2::Protocol {
578    match protocol {
579        IpNextLevelProtocol::Tcp => socket2::Protocol::TCP,
580        IpNextLevelProtocol::Udp => socket2::Protocol::UDP,
581        IpNextLevelProtocol::Icmp => socket2::Protocol::ICMPV4,
582        IpNextLevelProtocol::Icmpv6 => socket2::Protocol::ICMPV6,
583        _ => socket2::Protocol::TCP,
584    }
585}