1use std::borrow::Borrow;
2use std::io;
3use std::pin::Pin;
4use std::task::{ready, Context, Poll};
5
6use bytes::{BufMut, Bytes, BytesMut};
7use futures::Sink;
8use futures_core::Stream;
9
10#[cfg(all(target_os = "linux", not(target_env = "ohos")))]
11use crate::platform::offload::VirtioNetHdr;
12use crate::AsyncDevice;
13#[cfg(all(target_os = "linux", not(target_env = "ohos")))]
14use crate::{GROTable, IDEAL_BATCH_SIZE, VIRTIO_NET_HDR_LEN};
15
16pub trait Decoder {
17 type Item;
19
20 type Error: From<io::Error>;
22 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error>;
23 fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
24 match self.decode(buf)? {
25 Some(frame) => Ok(Some(frame)),
26 None => {
27 if buf.is_empty() {
28 Ok(None)
29 } else {
30 Err(io::Error::other("bytes remaining on stream").into())
31 }
32 }
33 }
34 }
35}
36
37impl<T: Decoder> Decoder for &mut T {
38 type Item = T::Item;
39 type Error = T::Error;
40
41 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
42 T::decode(self, src)
43 }
44}
45
46pub trait Encoder<Item> {
47 type Error: From<io::Error>;
49
50 fn encode(&mut self, item: Item, dst: &mut BytesMut) -> Result<(), Self::Error>;
52}
53
54impl<T: Encoder<Item>, Item> Encoder<Item> for &mut T {
55 type Error = T::Error;
56
57 fn encode(&mut self, item: Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
58 T::encode(self, item, dst)
59 }
60}
61
62pub struct DeviceFramed<C, T = AsyncDevice> {
124 dev: T,
125 codec: C,
126 r_state: ReadState,
127 w_state: WriteState,
128}
129impl<C, T> Unpin for DeviceFramed<C, T> {}
130impl<C, T> Stream for DeviceFramed<C, T>
131where
132 T: Borrow<AsyncDevice>,
133 C: Decoder,
134{
135 type Item = Result<C::Item, C::Error>;
136 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
137 let pin = self.get_mut();
138 DeviceFramedReadInner::new(&pin.dev, &mut pin.codec, &mut pin.r_state).poll_next(cx)
139 }
140}
141impl<I, C, T> Sink<I> for DeviceFramed<C, T>
142where
143 T: Borrow<AsyncDevice>,
144 C: Encoder<I>,
145{
146 type Error = C::Error;
147
148 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
149 let pin = self.get_mut();
150 DeviceFramedWriteInner::new(&pin.dev, &mut pin.codec, &mut pin.w_state).poll_ready(cx)
151 }
152
153 fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
154 let pin = self.get_mut();
155 DeviceFramedWriteInner::new(&pin.dev, &mut pin.codec, &mut pin.w_state).start_send(item)
156 }
157
158 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
159 let pin = self.get_mut();
160 DeviceFramedWriteInner::new(&pin.dev, &mut pin.codec, &mut pin.w_state).poll_flush(cx)
161 }
162
163 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
164 let pin = self.get_mut();
165 DeviceFramedWriteInner::new(&pin.dev, &mut pin.codec, &mut pin.w_state).poll_close(cx)
166 }
167}
168impl<C, T> DeviceFramed<C, T>
169where
170 T: Borrow<AsyncDevice>,
171{
172 pub fn new(dev: T, codec: C) -> DeviceFramed<C, T> {
174 let buffer_size = compute_buffer_size(&dev);
175 DeviceFramed {
176 r_state: ReadState::new(buffer_size, dev.borrow()),
177 w_state: WriteState::new(buffer_size, dev.borrow()),
178 dev,
179 codec,
180 }
181 }
182 pub fn read_buffer_size(&self) -> usize {
183 self.r_state.read_buffer_size()
184 }
185 pub fn write_buffer_size(&self) -> usize {
186 self.w_state.write_buffer_size()
187 }
188
189 pub fn set_read_buffer_size(&mut self, read_buffer_size: usize) {
193 self.r_state.set_read_buffer_size(read_buffer_size);
194 }
195 pub fn set_write_buffer_size(&mut self, write_buffer_size: usize) {
206 self.w_state.set_write_buffer_size(write_buffer_size);
207 }
208 pub fn read_buffer(&self) -> &BytesMut {
210 &self.r_state.rd
211 }
212
213 pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
215 &mut self.r_state.rd
216 }
217 pub fn into_inner(self) -> T {
219 self.dev
220 }
221}
222
223impl<C, T> DeviceFramed<C, T>
224where
225 T: Borrow<AsyncDevice> + Clone,
226 C: Clone,
227{
228 pub fn split(self) -> (DeviceFramedRead<C, T>, DeviceFramedWrite<C, T>) {
246 let dev = self.dev;
247 let codec = self.codec;
248 (
249 DeviceFramedRead::new(dev.clone(), codec.clone()),
250 DeviceFramedWrite::new(dev, codec),
251 )
252 }
253}
254
255pub struct DeviceFramedRead<C, T = AsyncDevice> {
298 dev: T,
299 codec: C,
300 state: ReadState,
301}
302impl<C, T> DeviceFramedRead<C, T>
303where
304 T: Borrow<AsyncDevice>,
305{
306 pub fn new(dev: T, codec: C) -> DeviceFramedRead<C, T> {
328 let buffer_size = compute_buffer_size(&dev);
329 DeviceFramedRead {
330 state: ReadState::new(buffer_size, dev.borrow()),
331 dev,
332 codec,
333 }
334 }
335 pub fn read_buffer_size(&self) -> usize {
336 self.state.read_buffer_size()
337 }
338 pub fn set_read_buffer_size(&mut self, read_buffer_size: usize) {
342 self.state.set_read_buffer_size(read_buffer_size);
343 }
344 pub fn into_inner(self) -> T {
346 self.dev
347 }
348}
349impl<C, T> Unpin for DeviceFramedRead<C, T> {}
350impl<C, T> Stream for DeviceFramedRead<C, T>
351where
352 T: Borrow<AsyncDevice>,
353 C: Decoder,
354{
355 type Item = Result<C::Item, C::Error>;
356 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
357 let pin = self.get_mut();
358 DeviceFramedReadInner::new(&pin.dev, &mut pin.codec, &mut pin.state).poll_next(cx)
359 }
360}
361
362pub struct DeviceFramedWrite<C, T = AsyncDevice> {
403 dev: T,
404 codec: C,
405 state: WriteState,
406}
407impl<C, T> DeviceFramedWrite<C, T>
408where
409 T: Borrow<AsyncDevice>,
410{
411 pub fn new(dev: T, codec: C) -> DeviceFramedWrite<C, T> {
433 let buffer_size = compute_buffer_size(&dev);
434 DeviceFramedWrite {
435 state: WriteState::new(buffer_size, dev.borrow()),
436 dev,
437 codec,
438 }
439 }
440 pub fn write_buffer_size(&self) -> usize {
441 self.state.send_buffer_size
442 }
443 pub fn set_write_buffer_size(&mut self, write_buffer_size: usize) {
454 self.state.set_write_buffer_size(write_buffer_size);
455 }
456
457 pub fn into_inner(self) -> T {
459 self.dev
460 }
461}
462
463impl<C, T> Unpin for DeviceFramedWrite<C, T> {}
464impl<I, C, T> Sink<I> for DeviceFramedWrite<C, T>
465where
466 T: Borrow<AsyncDevice>,
467 C: Encoder<I>,
468{
469 type Error = C::Error;
470
471 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
472 let pin = self.get_mut();
473 DeviceFramedWriteInner::new(&pin.dev, &mut pin.codec, &mut pin.state).poll_ready(cx)
474 }
475
476 fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
477 let pin = self.get_mut();
478 DeviceFramedWriteInner::new(&pin.dev, &mut pin.codec, &mut pin.state).start_send(item)
479 }
480
481 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
482 let pin = self.get_mut();
483 DeviceFramedWriteInner::new(&pin.dev, &mut pin.codec, &mut pin.state).poll_flush(cx)
484 }
485
486 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
487 let pin = self.get_mut();
488 DeviceFramedWriteInner::new(&pin.dev, &mut pin.codec, &mut pin.state).poll_close(cx)
489 }
490}
491fn compute_buffer_size<T: Borrow<AsyncDevice>>(_dev: &T) -> usize {
492 #[cfg(any(
493 target_os = "windows",
494 all(target_os = "linux", not(target_env = "ohos")),
495 target_os = "macos",
496 target_os = "freebsd",
497 target_os = "openbsd",
498 ))]
499 let mtu = _dev.borrow().mtu().map(|m| m as usize).unwrap_or(4096);
500
501 #[cfg(not(any(
502 target_os = "windows",
503 all(target_os = "linux", not(target_env = "ohos")),
504 target_os = "macos",
505 target_os = "freebsd",
506 target_os = "openbsd",
507 )))]
508 let mtu = 4096usize;
509
510 #[cfg(windows)]
511 {
512 let mtu_v6 = _dev.borrow().mtu_v6().map(|m| m as usize).unwrap_or(4096);
513 mtu.max(mtu_v6)
514 }
515 #[cfg(not(windows))]
516 mtu
517}
518struct ReadState {
519 recv_buffer_size: usize,
520 rd: BytesMut,
521 #[cfg(all(target_os = "linux", not(target_env = "ohos")))]
522 packet_splitter: Option<PacketSplitter>,
523}
524impl ReadState {
525 pub(crate) fn new(recv_buffer_size: usize, _device: &AsyncDevice) -> ReadState {
526 #[cfg(all(target_os = "linux", not(target_env = "ohos")))]
527 let packet_splitter = if _device.tcp_gso() {
528 Some(PacketSplitter::new(recv_buffer_size))
529 } else {
530 None
531 };
532
533 ReadState {
534 recv_buffer_size,
535 rd: BytesMut::with_capacity(recv_buffer_size),
536 #[cfg(all(target_os = "linux", not(target_env = "ohos")))]
537 packet_splitter,
538 }
539 }
540
541 pub(crate) fn read_buffer_size(&self) -> usize {
542 self.recv_buffer_size
543 }
544
545 pub(crate) fn set_read_buffer_size(&mut self, read_buffer_size: usize) {
546 self.recv_buffer_size = read_buffer_size;
547 #[cfg(all(target_os = "linux", not(target_env = "ohos")))]
548 if let Some(packet_splitter) = &mut self.packet_splitter {
549 packet_splitter.set_recv_buffer_size(read_buffer_size);
550 }
551 }
552}
553struct WriteState {
554 send_buffer_size: usize,
555 wr: BytesMut,
556 #[cfg(all(target_os = "linux", not(target_env = "ohos")))]
557 packet_arena: Option<PacketArena>,
558}
559impl WriteState {
560 pub(crate) fn new(send_buffer_size: usize, _device: &AsyncDevice) -> WriteState {
561 #[cfg(all(target_os = "linux", not(target_env = "ohos")))]
562 let packet_arena = if _device.tcp_gso() {
563 Some(PacketArena::new())
564 } else {
565 None
566 };
567
568 WriteState {
569 send_buffer_size,
570 wr: BytesMut::new(),
571 #[cfg(all(target_os = "linux", not(target_env = "ohos")))]
572 packet_arena,
573 }
574 }
575 pub(crate) fn write_buffer_size(&self) -> usize {
576 self.send_buffer_size
577 }
578
579 pub(crate) fn set_write_buffer_size(&mut self, write_buffer_size: usize) {
580 #[cfg(all(target_os = "linux", not(target_env = "ohos")))]
581 if self.packet_arena.is_some() {
582 return;
584 }
585 if self.send_buffer_size >= write_buffer_size {
586 return;
587 }
588 self.send_buffer_size = write_buffer_size;
589 }
590}
591
592#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Default)]
593pub struct BytesCodec(());
594impl BytesCodec {
595 pub fn new() -> BytesCodec {
597 BytesCodec(())
598 }
599}
600impl Decoder for BytesCodec {
601 type Item = BytesMut;
602 type Error = io::Error;
603
604 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
605 if !buf.is_empty() {
606 let rs = buf.clone();
607 buf.clear();
608 Ok(Some(rs))
609 } else {
610 Ok(None)
611 }
612 }
613}
614
615impl Encoder<Bytes> for BytesCodec {
616 type Error = io::Error;
617
618 fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> {
619 buf.reserve(data.len());
620 buf.put(data);
621 Ok(())
622 }
623}
624
625impl Encoder<BytesMut> for BytesCodec {
626 type Error = io::Error;
627
628 fn encode(&mut self, data: BytesMut, buf: &mut BytesMut) -> Result<(), io::Error> {
629 buf.reserve(data.len());
630 buf.put(data);
631 Ok(())
632 }
633}
634
635#[cfg(all(target_os = "linux", not(target_env = "ohos")))]
636struct PacketSplitter {
637 bufs: Vec<BytesMut>,
638 sizes: Vec<usize>,
639 recv_index: usize,
640 recv_num: usize,
641 recv_buffer_size: usize,
642}
643#[cfg(all(target_os = "linux", not(target_env = "ohos")))]
644impl PacketSplitter {
645 fn new(recv_buffer_size: usize) -> PacketSplitter {
646 let bufs = vec![BytesMut::zeroed(recv_buffer_size); IDEAL_BATCH_SIZE];
647 let sizes = vec![0usize; IDEAL_BATCH_SIZE];
648 Self {
649 bufs,
650 sizes,
651 recv_index: 0,
652 recv_num: 0,
653 recv_buffer_size,
654 }
655 }
656 fn handle(&mut self, dev: &AsyncDevice, input: &mut [u8]) -> io::Result<()> {
657 if input.len() <= VIRTIO_NET_HDR_LEN {
658 Err(io::Error::other(format!(
659 "length of packet ({}) <= VIRTIO_NET_HDR_LEN ({VIRTIO_NET_HDR_LEN})",
660 input.len(),
661 )))?
662 }
663 for buf in &mut self.bufs {
664 buf.resize(self.recv_buffer_size, 0);
665 }
666 let hdr = VirtioNetHdr::decode(&input[..VIRTIO_NET_HDR_LEN])?;
667 let num = dev.handle_virtio_read(
668 hdr,
669 &mut input[VIRTIO_NET_HDR_LEN..],
670 &mut self.bufs,
671 &mut self.sizes,
672 0,
673 )?;
674
675 for i in 0..num {
676 self.bufs[i].truncate(self.sizes[i]);
677 }
678 self.recv_num = num;
679 self.recv_index = 0;
680 Ok(())
681 }
682 fn next(&mut self) -> Option<&mut BytesMut> {
683 if self.recv_index >= self.recv_num {
684 None
685 } else {
686 let buf = &mut self.bufs[self.recv_index];
687 self.recv_index += 1;
688 Some(buf)
689 }
690 }
691 fn set_recv_buffer_size(&mut self, recv_buffer_size: usize) {
692 self.recv_buffer_size = recv_buffer_size;
693 }
694}
695#[cfg(all(target_os = "linux", not(target_env = "ohos")))]
696struct PacketArena {
697 gro_table: GROTable,
698 offset: usize,
699 bufs: Vec<BytesMut>,
700 send_index: usize,
701}
702#[cfg(all(target_os = "linux", not(target_env = "ohos")))]
703impl PacketArena {
704 fn new() -> PacketArena {
705 Self {
706 gro_table: Default::default(),
707 offset: 0,
708 bufs: Vec::with_capacity(IDEAL_BATCH_SIZE),
709 send_index: 0,
710 }
711 }
712 fn get(&mut self) -> &mut BytesMut {
713 if self.offset < self.bufs.len() {
714 let buf = &mut self.bufs[self.offset];
715 self.offset += 1;
716 buf.clear();
717 buf.reserve(VIRTIO_NET_HDR_LEN + 65536);
718 return buf;
719 }
720 assert_eq!(self.offset, self.bufs.len());
721 self.bufs
722 .push(BytesMut::with_capacity(VIRTIO_NET_HDR_LEN + 65536));
723 let idx = self.offset;
724 self.offset += 1;
725 &mut self.bufs[idx]
726 }
727 fn handle(&mut self, dev: &AsyncDevice) -> io::Result<()> {
728 if self.offset == 0 {
729 return Ok(());
730 }
731 if !self.gro_table.to_write.is_empty() {
732 return Ok(());
733 }
734 crate::platform::offload::handle_gro(
735 &mut self.bufs[..self.offset],
736 VIRTIO_NET_HDR_LEN,
737 &mut self.gro_table.tcp_gro_table,
738 &mut self.gro_table.udp_gro_table,
739 dev.udp_gso,
740 &mut self.gro_table.to_write,
741 )
742 }
743 fn poll_send_bufs(&mut self, cx: &mut Context<'_>, dev: &AsyncDevice) -> Poll<io::Result<()>> {
744 if self.offset == 0 {
745 return Poll::Ready(Ok(()));
746 }
747 let gro_table = &mut self.gro_table;
748 let bufs = &self.bufs[..self.offset];
749 for buf_idx in &gro_table.to_write[self.send_index..] {
750 let rs = dev.poll_send(cx, &bufs[*buf_idx]);
751 match rs {
752 Poll::Ready(Ok(_)) => {
753 self.send_index += 1;
754 }
755 Poll::Ready(Err(e)) => {
756 self.send_index += 1;
757 if self.send_index >= gro_table.to_write.len() {
758 self.reset();
759 }
760 return Poll::Ready(Err(e));
761 }
762 Poll::Pending => {
763 return Poll::Pending;
764 }
765 }
766 }
767 self.reset();
768 Poll::Ready(Ok(()))
769 }
770 fn reset(&mut self) {
771 self.gro_table.reset();
772 for buf in self.bufs[..self.offset].iter_mut() {
773 buf.clear();
774 }
775 self.offset = 0;
776 self.send_index = 0;
777 }
778 fn is_idle(&self) -> bool {
779 IDEAL_BATCH_SIZE > self.offset && self.gro_table.to_write.is_empty()
780 }
781}
782struct DeviceFramedReadInner<'a, C, T = AsyncDevice> {
783 dev: &'a T,
784 codec: &'a mut C,
785 state: &'a mut ReadState,
786}
787impl<'a, C, T> DeviceFramedReadInner<'a, C, T>
788where
789 T: Borrow<AsyncDevice>,
790 C: Decoder,
791{
792 fn new(
793 dev: &'a T,
794 codec: &'a mut C,
795 state: &'a mut ReadState,
796 ) -> DeviceFramedReadInner<'a, C, T> {
797 DeviceFramedReadInner { dev, codec, state }
798 }
799
800 fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<C::Item, C::Error>>> {
801 #[cfg(all(target_os = "linux", not(target_env = "ohos")))]
802 if let Some(packet_splitter) = &mut self.state.packet_splitter {
803 if let Some(buf) = packet_splitter.next() {
804 if let Some(frame) = self.codec.decode_eof(buf)? {
805 return Poll::Ready(Some(Ok(frame)));
806 }
807 }
808 }
809
810 self.state.rd.clear();
811 #[cfg(all(target_os = "linux", not(target_env = "ohos")))]
812 if self.state.packet_splitter.is_some() {
813 self.state.rd.reserve(VIRTIO_NET_HDR_LEN + 65536);
814 }
815 self.state.rd.reserve(self.state.recv_buffer_size);
816 let buf = unsafe { &mut *(self.state.rd.chunk_mut() as *mut _ as *mut [u8]) };
817
818 let len = ready!(self.dev.borrow().poll_recv(cx, buf))?;
819 unsafe { self.state.rd.advance_mut(len) };
820
821 #[cfg(all(target_os = "linux", not(target_env = "ohos")))]
822 if let Some(packet_splitter) = &mut self.state.packet_splitter {
823 packet_splitter.handle(self.dev.borrow(), &mut self.state.rd)?;
824 if let Some(buf) = packet_splitter.next() {
825 if let Some(frame) = self.codec.decode_eof(buf)? {
826 return Poll::Ready(Some(Ok(frame)));
827 }
828 }
829 return Poll::Ready(None);
830 }
831 if let Some(frame) = self.codec.decode_eof(&mut self.state.rd)? {
832 return Poll::Ready(Some(Ok(frame)));
833 }
834 Poll::Ready(None)
835 }
836}
837struct DeviceFramedWriteInner<'a, C, T = AsyncDevice> {
838 dev: &'a T,
839 codec: &'a mut C,
840 state: &'a mut WriteState,
841}
842impl<'a, C, T> DeviceFramedWriteInner<'a, C, T>
843where
844 T: Borrow<AsyncDevice>,
845{
846 fn new(
847 dev: &'a T,
848 codec: &'a mut C,
849 state: &'a mut WriteState,
850 ) -> DeviceFramedWriteInner<'a, C, T> {
851 DeviceFramedWriteInner { dev, codec, state }
852 }
853
854 fn poll_ready<I>(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), C::Error>>
855 where
856 C: Encoder<I>,
857 {
858 #[cfg(all(target_os = "linux", not(target_env = "ohos")))]
859 if let Some(packet_arena) = &self.state.packet_arena {
860 if packet_arena.is_idle() {
861 return Poll::Ready(Ok(()));
862 }
863 }
864 ready!(self.poll_flush(cx))?;
865 Poll::Ready(Ok(()))
866 }
867
868 fn start_send<I>(&mut self, item: I) -> Result<(), C::Error>
869 where
870 C: Encoder<I>,
871 {
872 #[cfg(all(target_os = "linux", not(target_env = "ohos")))]
873 if let Some(packet_arena) = &mut self.state.packet_arena {
874 let buf = packet_arena.get();
875 buf.resize(VIRTIO_NET_HDR_LEN, 0);
876 self.codec.encode(item, buf)?;
877 return Ok(());
878 }
879 let buf = &mut self.state.wr;
880 buf.clear();
881 buf.reserve(self.state.send_buffer_size);
882 self.codec.encode(item, buf)?;
883 Ok(())
884 }
885
886 fn poll_flush<I>(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), C::Error>>
887 where
888 C: Encoder<I>,
889 {
890 let dev = self.dev.borrow();
891
892 #[cfg(all(target_os = "linux", not(target_env = "ohos")))]
893 if let Some(packet_arena) = &mut self.state.packet_arena {
894 packet_arena.handle(dev)?;
895 ready!(packet_arena.poll_send_bufs(cx, dev))?;
896 return Poll::Ready(Ok(()));
897 }
898
899 if !self.state.wr.is_empty() {
901 let rs = ready!(dev.poll_send(cx, &self.state.wr));
902 self.state.wr.clear();
903 rs?;
904 }
905 Poll::Ready(Ok(()))
906 }
907
908 fn poll_close<I>(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), C::Error>>
909 where
910 C: Encoder<I>,
911 {
912 ready!(self.poll_flush(cx))?;
913 Poll::Ready(Ok(()))
914 }
915}