1#[cfg(not(target_os = "windows"))]
2mod unix;
3#[cfg(not(target_os = "windows"))]
4pub use unix::*;
5
6#[cfg(target_os = "windows")]
7mod windows;
8#[cfg(target_os = "windows")]
9pub use windows::*;
10
11use async_io::Async;
12use socket2::{Domain, SockAddr, Socket as SystemSocket, Type};
13use std::io;
14use std::mem::MaybeUninit;
15use std::net::{Shutdown, SocketAddr};
16use std::sync::Arc;
17use std::time::Duration;
18
19use xenet_packet::ip::IpNextLevelProtocol;
20
/// IP protocol version (address family) used when creating a socket.
///
/// Equality and hashing are derived so callers can compare versions or use
/// them as map keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum IpVersion {
    /// IPv4 (`socket2::Domain::IPV4`).
    V4,
    /// IPv6 (`socket2::Domain::IPV6`).
    V6,
}
27
28impl IpVersion {
29 pub fn version_u8(&self) -> u8 {
31 match self {
32 IpVersion::V4 => 4,
33 IpVersion::V6 => 6,
34 }
35 }
36 pub fn is_ipv4(&self) -> bool {
38 match self {
39 IpVersion::V4 => true,
40 IpVersion::V6 => false,
41 }
42 }
43 pub fn is_ipv6(&self) -> bool {
45 match self {
46 IpVersion::V4 => false,
47 IpVersion::V6 => true,
48 }
49 }
50 pub(crate) fn to_domain(&self) -> Domain {
51 match self {
52 IpVersion::V4 => Domain::IPV4,
53 IpVersion::V6 => Domain::IPV6,
54 }
55 }
56}
57
/// Kind of socket to create.
///
/// Equality and hashing are derived so callers can compare socket kinds or
/// use them as map keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum SocketType {
    /// Raw socket (`socket2::Type::RAW`).
    Raw,
    /// Datagram socket (`socket2::Type::DGRAM`).
    Datagram,
    /// Stream socket (`socket2::Type::STREAM`).
    Stream,
}
68
69impl SocketType {
70 pub(crate) fn to_type(&self) -> Type {
71 match self {
72 SocketType::Raw => Type::RAW,
73 SocketType::Datagram => Type::DGRAM,
74 SocketType::Stream => Type::STREAM,
75 }
76 }
77}
78
79#[derive(Clone, Debug)]
81pub struct SocketOption {
82 pub ip_version: IpVersion,
84 pub socket_type: SocketType,
86 pub protocol: Option<IpNextLevelProtocol>,
88 pub timeout: Option<u64>,
90 pub ttl: Option<u32>,
92 pub non_blocking: bool,
94}
95
96impl SocketOption {
97 pub fn new(
99 ip_version: IpVersion,
100 socket_type: SocketType,
101 protocol: Option<IpNextLevelProtocol>,
102 ) -> SocketOption {
103 SocketOption {
104 ip_version,
105 socket_type,
106 protocol,
107 timeout: None,
108 ttl: None,
109 non_blocking: false,
110 }
111 }
112 pub fn is_valid(&self) -> Result<(), String> {
115 check_socket_option(self.clone())
116 }
117}
118
119#[derive(Clone, Debug)]
121pub struct AsyncSocket {
122 inner: Arc<Async<SystemSocket>>,
123}
124
125impl AsyncSocket {
126 pub fn new(socket_option: SocketOption) -> io::Result<AsyncSocket> {
128 let socket: SystemSocket = if let Some(protocol) = socket_option.protocol {
129 SystemSocket::new(
130 socket_option.ip_version.to_domain(),
131 socket_option.socket_type.to_type(),
132 Some(to_socket_protocol(protocol)),
133 )?
134 } else {
135 SystemSocket::new(
136 socket_option.ip_version.to_domain(),
137 socket_option.socket_type.to_type(),
138 None,
139 )?
140 };
141 socket.set_nonblocking(true)?;
142 Ok(AsyncSocket {
143 inner: Arc::new(Async::new(socket)?),
144 })
145 }
146 pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
148 loop {
149 self.inner.writable().await?;
150 match self.inner.write_with(|inner| inner.send(buf)).await {
151 Ok(n) => return Ok(n),
152 Err(_) => continue,
153 }
154 }
155 }
156 pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
158 let target: SockAddr = SockAddr::from(target);
159 loop {
160 self.inner.writable().await?;
161 match self
162 .inner
163 .write_with(|inner| inner.send_to(buf, &target))
164 .await
165 {
166 Ok(n) => return Ok(n),
167 Err(_) => continue,
168 }
169 }
170 }
171 pub async fn receive(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
173 let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
174 loop {
175 self.inner.readable().await?;
176 match self.inner.read_with(|inner| inner.recv(recv_buf)).await {
177 Ok(result) => return Ok(result),
178 Err(_) => continue,
179 }
180 }
181 }
182 pub async fn receive_from(&self, buf: &mut Vec<u8>) -> io::Result<(usize, SocketAddr)> {
184 let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
185 loop {
186 self.inner.readable().await?;
187 match self
188 .inner
189 .read_with(|inner| inner.recv_from(recv_buf))
190 .await
191 {
192 Ok(result) => {
193 let (n, addr) = result;
194 match addr.as_socket() {
195 Some(addr) => return Ok((n, addr)),
196 None => continue,
197 }
198 }
199 Err(_) => continue,
200 }
201 }
202 }
203 pub async fn write(&self, buf: &[u8]) -> io::Result<usize> {
206 loop {
207 self.inner.writable().await?;
208 match self.inner.write_with(|inner| inner.send(buf)).await {
209 Ok(n) => return Ok(n),
210 Err(_) => continue,
211 }
212 }
213 }
214 pub async fn read(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
217 let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
218 loop {
219 self.inner.readable().await?;
220 match self.inner.read_with(|inner| inner.recv(recv_buf)).await {
221 Ok(result) => return Ok(result),
222 Err(_) => continue,
223 }
224 }
225 }
226 pub async fn bind(&self, addr: SocketAddr) -> io::Result<()> {
228 let addr: SockAddr = SockAddr::from(addr);
229 self.inner.writable().await?;
230 self.inner.write_with(|inner| inner.bind(&addr)).await
231 }
232 pub async fn set_receive_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
234 self.inner.writable().await?;
235 self.inner
236 .write_with(|inner| inner.set_read_timeout(timeout))
237 .await
238 }
239 pub async fn set_ttl(&self, ttl: u32, ip_version: IpVersion) -> io::Result<()> {
241 self.inner.writable().await?;
242 match ip_version {
243 IpVersion::V4 => self.inner.write_with(|inner| inner.set_ttl(ttl)).await,
244 IpVersion::V6 => {
245 self.inner
246 .write_with(|inner| inner.set_unicast_hops_v6(ttl))
247 .await
248 }
249 }
250 }
251 pub async fn connect(&self, addr: SocketAddr) -> io::Result<()> {
253 let addr: SockAddr = SockAddr::from(addr);
254 self.inner.writable().await?;
255 self.inner.write_with(|inner| inner.connect(&addr)).await
256 }
257 pub async fn shutdown(&self, how: Shutdown) -> io::Result<()> {
259 self.inner.writable().await?;
260 self.inner.write_with(|inner| inner.shutdown(how)).await
261 }
262 pub async fn listen(&self, backlog: i32) -> io::Result<()> {
264 self.inner.writable().await?;
265 self.inner.write_with(|inner| inner.listen(backlog)).await
266 }
267 pub async fn accept(&self) -> io::Result<(AsyncSocket, SocketAddr)> {
269 self.inner.readable().await?;
270 match self.inner.read_with(|inner| inner.accept()).await {
271 Ok((socket, addr)) => {
272 let socket = AsyncSocket {
273 inner: Arc::new(Async::new(socket)?),
274 };
275 Ok((socket, addr.as_socket().unwrap()))
276 }
277 Err(e) => Err(e),
278 }
279 }
280 pub async fn peer_addr(&self) -> io::Result<SocketAddr> {
282 self.inner.writable().await?;
283 match self.inner.read_with(|inner| inner.peer_addr()).await {
284 Ok(addr) => Ok(addr.as_socket().unwrap()),
285 Err(e) => Err(e),
286 }
287 }
288 pub async fn local_addr(&self) -> io::Result<SocketAddr> {
290 self.inner.writable().await?;
291 match self.inner.read_with(|inner| inner.local_addr()).await {
292 Ok(addr) => Ok(addr.as_socket().unwrap()),
293 Err(e) => Err(e),
294 }
295 }
296 pub async fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
299 let addr: SockAddr = SockAddr::from(*addr);
300 self.inner.writable().await?;
301 self.inner
302 .write_with(|inner| inner.connect_timeout(&addr, timeout))
303 .await
304 }
305 pub async fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
309 self.inner.writable().await?;
310 self.inner
311 .write_with(|inner| inner.set_nonblocking(nonblocking))
312 .await
313 }
314 pub async fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
318 self.inner.writable().await?;
319 self.inner
320 .write_with(|inner| inner.set_broadcast(broadcast))
321 .await
322 }
323 pub async fn get_error(&self) -> io::Result<Option<io::Error>> {
325 self.inner.readable().await?;
326 self.inner.read_with(|inner| inner.take_error()).await
327 }
328 pub async fn set_keepalive(&self, keepalive: bool) -> io::Result<()> {
332 self.inner.writable().await?;
333 self.inner
334 .write_with(|inner| inner.set_keepalive(keepalive))
335 .await
336 }
337 pub async fn set_receive_buffer_size(&self, size: usize) -> io::Result<()> {
341 self.inner.writable().await?;
342 self.inner
343 .write_with(|inner| inner.set_recv_buffer_size(size))
344 .await
345 }
346 pub async fn set_reuse_address(&self, reuse: bool) -> io::Result<()> {
350 self.inner.writable().await?;
351 self.inner
352 .write_with(|inner| inner.set_reuse_address(reuse))
353 .await
354 }
355 pub async fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
359 self.inner.writable().await?;
360 self.inner
361 .write_with(|inner| inner.set_send_buffer_size(size))
362 .await
363 }
364 pub async fn set_send_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
368 self.inner.writable().await?;
369 self.inner
370 .write_with(|inner| inner.set_write_timeout(duration))
371 .await
372 }
373 pub async fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
377 self.inner.writable().await?;
378 self.inner
379 .write_with(|inner| inner.set_nodelay(nodelay))
380 .await
381 }
382}
383
384#[derive(Clone, Debug)]
386pub struct Socket {
387 inner: Arc<SystemSocket>,
388}
389
390impl Socket {
391 pub fn new(socket_option: SocketOption) -> io::Result<Socket> {
393 let socket: SystemSocket = if let Some(protocol) = socket_option.protocol {
394 SystemSocket::new(
395 socket_option.ip_version.to_domain(),
396 socket_option.socket_type.to_type(),
397 Some(to_socket_protocol(protocol)),
398 )?
399 } else {
400 SystemSocket::new(
401 socket_option.ip_version.to_domain(),
402 socket_option.socket_type.to_type(),
403 None,
404 )?
405 };
406 if socket_option.non_blocking {
407 socket.set_nonblocking(true)?;
408 }
409 Ok(Socket {
410 inner: Arc::new(socket),
411 })
412 }
413 pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
415 let target: SockAddr = SockAddr::from(target);
416 match self.inner.send_to(buf, &target) {
417 Ok(n) => Ok(n),
418 Err(e) => Err(e),
419 }
420 }
421 pub fn receive(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
423 let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
424 match self.inner.recv(recv_buf) {
425 Ok(result) => Ok(result),
426 Err(e) => Err(e),
427 }
428 }
429 pub fn receive_from(&self, buf: &mut Vec<u8>) -> io::Result<(usize, SocketAddr)> {
431 let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
432 match self.inner.recv_from(recv_buf) {
433 Ok(result) => {
434 let (n, addr) = result;
435 match addr.as_socket() {
436 Some(addr) => return Ok((n, addr)),
437 None => {
438 return Err(io::Error::new(
439 io::ErrorKind::Other,
440 "Invalid socket address",
441 ))
442 }
443 }
444 }
445 Err(e) => Err(e),
446 }
447 }
448 pub fn write(&self, buf: &[u8]) -> io::Result<usize> {
451 match self.inner.send(buf) {
452 Ok(n) => Ok(n),
453 Err(e) => Err(e),
454 }
455 }
456 pub fn read(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
459 let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
460 match self.inner.recv(recv_buf) {
461 Ok(result) => Ok(result),
462 Err(e) => Err(e),
463 }
464 }
465 pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
467 let addr: SockAddr = SockAddr::from(addr);
468 self.inner.bind(&addr)
469 }
470 pub fn set_receive_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
472 self.inner.set_read_timeout(timeout)
473 }
474 pub fn set_ttl(&self, ttl: u32, ip_version: IpVersion) -> io::Result<()> {
476 match ip_version {
477 IpVersion::V4 => self.inner.set_ttl(ttl),
478 IpVersion::V6 => self.inner.set_unicast_hops_v6(ttl),
479 }
480 }
481 pub fn connect(&self, addr: SocketAddr) -> io::Result<()> {
483 let addr: SockAddr = SockAddr::from(addr);
484 self.inner.connect(&addr)
485 }
486 pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
488 self.inner.shutdown(how)
489 }
490 pub fn listen(&self, backlog: i32) -> io::Result<()> {
492 self.inner.listen(backlog)
493 }
494 pub fn accept(&self) -> io::Result<(Socket, SocketAddr)> {
496 match self.inner.accept() {
497 Ok((socket, addr)) => Ok((
498 Socket {
499 inner: Arc::new(socket),
500 },
501 addr.as_socket().unwrap(),
502 )),
503 Err(e) => Err(e),
504 }
505 }
506 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
508 match self.inner.peer_addr() {
509 Ok(addr) => Ok(addr.as_socket().unwrap()),
510 Err(e) => Err(e),
511 }
512 }
513 pub fn local_addr(&self) -> io::Result<SocketAddr> {
515 match self.inner.local_addr() {
516 Ok(addr) => Ok(addr.as_socket().unwrap()),
517 Err(e) => Err(e),
518 }
519 }
520 pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
523 let addr: SockAddr = SockAddr::from(*addr);
524 self.inner.connect_timeout(&addr, timeout)
525 }
526 pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
527 self.inner.set_nonblocking(nonblocking)
528 }
529 pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
533 self.inner.set_broadcast(broadcast)
534 }
535 pub fn get_error(&self) -> io::Result<Option<io::Error>> {
537 self.inner.take_error()
538 }
539 pub fn set_keepalive(&self, keepalive: bool) -> io::Result<()> {
543 self.inner.set_keepalive(keepalive)
544 }
545 pub fn set_receive_buffer_size(&self, size: usize) -> io::Result<()> {
549 self.inner.set_recv_buffer_size(size)
550 }
551 pub fn set_reuse_address(&self, reuse: bool) -> io::Result<()> {
555 self.inner.set_reuse_address(reuse)
556 }
557 pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
561 self.inner.set_send_buffer_size(size)
562 }
563 pub fn set_send_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
567 self.inner.set_write_timeout(duration)
568 }
569 pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
573 self.inner.set_nodelay(nodelay)
574 }
575}
576
577fn to_socket_protocol(protocol: IpNextLevelProtocol) -> socket2::Protocol {
578 match protocol {
579 IpNextLevelProtocol::Tcp => socket2::Protocol::TCP,
580 IpNextLevelProtocol::Udp => socket2::Protocol::UDP,
581 IpNextLevelProtocol::Icmp => socket2::Protocol::ICMPV4,
582 IpNextLevelProtocol::Icmpv6 => socket2::Protocol::ICMPV6,
583 _ => socket2::Protocol::TCP,
584 }
585}