1use crate::platform::linux::checksum::{checksum, pseudo_header_checksum_no_fold};
3use byteorder::{BigEndian, ByteOrder};
4use bytes::BytesMut;
5use libc::{IPPROTO_TCP, IPPROTO_UDP};
6use std::collections::HashMap;
7use std::io;
8
// virtio-net header `gso_type` / `flags` values, mirroring the Linux
// `struct virtio_net_hdr` constants.
pub const VIRTIO_NET_HDR_GSO_NONE: u8 = 0;
pub const VIRTIO_NET_HDR_F_NEEDS_CSUM: u8 = 1;
pub const VIRTIO_NET_HDR_GSO_TCPV4: u8 = 1;
pub const VIRTIO_NET_HDR_GSO_TCPV6: u8 = 4;
pub const VIRTIO_NET_HDR_GSO_UDP_L4: u8 = 5;

// Preferred number of packets handled per batch; also used to pre-size the
// GRO tables and their item pools.
pub const IDEAL_BATCH_SIZE: usize = 128;

// Byte offset of the flags byte within a TCP header.
const TCP_FLAGS_OFFSET: usize = 13;

// Individual TCP flag bits within that byte.
const TCP_FLAG_FIN: u8 = 0x01;
const TCP_FLAG_PSH: u8 = 0x08;
const TCP_FLAG_ACK: u8 = 0x10;
/// In-memory mirror of Linux's `struct virtio_net_hdr` (virtio spec).
/// Describes checksum/GSO offload state for the packet that follows it.
///
/// Multi-byte fields are (de)serialized in native byte order, exactly as the
/// previous raw-memory-copy implementation did.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct VirtioNetHdr {
    pub flags: u8,
    pub gso_type: u8,
    /// Length of the headers to replicate onto every produced segment.
    pub hdr_len: u16,
    /// Maximum payload bytes per segment when GSO applies.
    pub gso_size: u16,
    /// Offset where checksumming starts (usually the transport header).
    pub csum_start: u16,
    /// Offset of the checksum field, relative to `csum_start`.
    pub csum_offset: u16,
}

impl VirtioNetHdr {
    /// Deserializes a header from the front of `buf`.
    ///
    /// # Errors
    /// `InvalidInput` if `buf` is shorter than [`VIRTIO_NET_HDR_LEN`].
    pub fn decode(buf: &[u8]) -> io::Result<VirtioNetHdr> {
        if buf.len() < VIRTIO_NET_HDR_LEN {
            return Err(io::Error::new(io::ErrorKind::InvalidInput, "too short"));
        }
        // repr(C) with fields (u8, u8, u16, u16, u16, u16) is exactly 10
        // bytes with no padding (offsets 0,1,2,4,6,8), so an explicit field
        // codec is equivalent to the former unsafe memcpy — minus the unsafe.
        Ok(VirtioNetHdr {
            flags: buf[0],
            gso_type: buf[1],
            hdr_len: u16::from_ne_bytes([buf[2], buf[3]]),
            gso_size: u16::from_ne_bytes([buf[4], buf[5]]),
            csum_start: u16::from_ne_bytes([buf[6], buf[7]]),
            csum_offset: u16::from_ne_bytes([buf[8], buf[9]]),
        })
    }

    /// Serializes the header into the front of `buf` (native byte order).
    ///
    /// # Errors
    /// `InvalidInput` if `buf` is shorter than [`VIRTIO_NET_HDR_LEN`].
    pub fn encode(&self, buf: &mut [u8]) -> io::Result<()> {
        if buf.len() < VIRTIO_NET_HDR_LEN {
            return Err(io::Error::new(io::ErrorKind::InvalidInput, "too short"));
        }
        buf[0] = self.flags;
        buf[1] = self.gso_type;
        buf[2..4].copy_from_slice(&self.hdr_len.to_ne_bytes());
        buf[4..6].copy_from_slice(&self.gso_size.to_ne_bytes());
        buf[6..8].copy_from_slice(&self.csum_start.to_ne_bytes());
        buf[8..10].copy_from_slice(&self.csum_offset.to_ne_bytes());
        Ok(())
    }
}

/// Size in bytes of the encoded header (10).
pub const VIRTIO_NET_HDR_LEN: usize = std::mem::size_of::<VirtioNetHdr>();
86
/// Hash key identifying a TCP flow eligible for coalescing.
/// Addresses are zero-padded to 16 bytes so IPv4 and IPv6 share one layout.
/// The received ACK number (`rx_ack`) is part of the key: segments carrying
/// different ACK values must not be merged, so they land in different buckets.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct TcpFlowKey {
    src_addr: [u8; 16],
    dst_addr: [u8; 16],
    src_port: u16,
    dst_port: u16,
    rx_ack: u32, is_v6: bool,
}
97
/// Per-batch table of TCP packets grouped by flow, awaiting coalescing.
pub struct TcpGROTable {
    // Flow -> packets seen for that flow in the current batch.
    items_by_flow: HashMap<TcpFlowKey, Vec<TcpGROItem>>,
    // Recycled item vectors, reused across batches to avoid reallocation.
    items_pool: Vec<Vec<TcpGROItem>>,
}
103
104impl Default for TcpGROTable {
105 fn default() -> Self {
106 Self::new()
107 }
108}
109
110impl TcpGROTable {
111 fn new() -> Self {
112 let mut items_pool = Vec::with_capacity(IDEAL_BATCH_SIZE);
113 for _ in 0..IDEAL_BATCH_SIZE {
114 items_pool.push(Vec::with_capacity(IDEAL_BATCH_SIZE));
115 }
116 TcpGROTable {
117 items_by_flow: HashMap::with_capacity(IDEAL_BATCH_SIZE),
118 items_pool,
119 }
120 }
121}
122
123impl TcpFlowKey {
124 fn new(pkt: &[u8], src_addr_offset: usize, dst_addr_offset: usize, tcph_offset: usize) -> Self {
125 let mut key = TcpFlowKey {
126 src_addr: [0; 16],
127 dst_addr: [0; 16],
128 src_port: 0,
129 dst_port: 0,
130 rx_ack: 0,
131 is_v6: false,
132 };
133
134 let addr_size = dst_addr_offset - src_addr_offset;
135 key.src_addr[..addr_size].copy_from_slice(&pkt[src_addr_offset..dst_addr_offset]);
136 key.dst_addr[..addr_size]
137 .copy_from_slice(&pkt[dst_addr_offset..dst_addr_offset + addr_size]);
138 key.src_port = BigEndian::read_u16(&pkt[tcph_offset..]);
139 key.dst_port = BigEndian::read_u16(&pkt[tcph_offset + 2..]);
140 key.rx_ack = BigEndian::read_u32(&pkt[tcph_offset + 8..]);
141 key.is_v6 = addr_size == 16;
142 key
143 }
144}
145
impl TcpGROTable {
    /// Looks up the flow for `pkt`. On a hit, returns the flow's existing
    /// item list so the caller can attempt coalescing. On a miss, records the
    /// packet as a new item and returns `None` (the caller treats that as a
    /// table insert rather than a coalesce).
    fn lookup_or_insert(
        &mut self,
        pkt: &[u8],
        src_addr_offset: usize,
        dst_addr_offset: usize,
        tcph_offset: usize,
        tcph_len: usize,
        bufs_index: usize,
    ) -> Option<&mut Vec<TcpGROItem>> {
        let key = TcpFlowKey::new(pkt, src_addr_offset, dst_addr_offset, tcph_offset);
        // Double lookup (contains_key + get_mut) sidesteps the NLL borrow
        // conflict an early-returned `get_mut` would have with `self.insert`.
        if self.items_by_flow.contains_key(&key) {
            return self.items_by_flow.get_mut(&key);
        }
        self.insert(
            pkt,
            src_addr_offset,
            dst_addr_offset,
            tcph_offset,
            tcph_len,
            bufs_index,
        );
        None
    }
    /// Records `pkt` (living in `bufs[bufs_index]`) as a fresh GRO item for
    /// its flow. The initial `gso_size` is the packet's TCP payload length.
    fn insert(
        &mut self,
        pkt: &[u8],
        src_addr_offset: usize,
        dst_addr_offset: usize,
        tcph_offset: usize,
        tcph_len: usize,
        bufs_index: usize,
    ) {
        let key = TcpFlowKey::new(pkt, src_addr_offset, dst_addr_offset, tcph_offset);
        let item = TcpGROItem {
            key,
            bufs_index: bufs_index as u16,
            num_merged: 0,
            // Payload bytes following the TCP header.
            gso_size: pkt[tcph_offset + tcph_len..].len() as u16,
            // tcph_offset equals the IP header length, since `pkt` starts at
            // the IP header.
            iph_len: tcph_offset as u8,
            tcph_len: tcph_len as u8,
            sent_seq: BigEndian::read_u32(&pkt[tcph_offset + 4..tcph_offset + 8]),
            psh_set: pkt[tcph_offset + TCP_FLAGS_OFFSET] & TCP_FLAG_PSH != 0,
        };

        // Reuse a pooled Vec when available. The closure captures only
        // `self.items_pool` (2021-edition disjoint capture), so it does not
        // conflict with the `items_by_flow` borrow held by `entry`.
        let items = self
            .items_by_flow
            .entry(key)
            .or_insert_with(|| self.items_pool.pop().unwrap_or_default());
        items.push(item);
    }
}
/// State for one (possibly already-merged) TCP packet held in the table:
/// `bufs_index` addresses the owning buffer, `num_merged` counts segments
/// merged into it, `gso_size` is the segment payload size, `sent_seq` the
/// sequence number of the packet's first byte, and `psh_set` whether the
/// (current) final segment carries PSH.
#[derive(Debug, Clone, Copy)]
pub struct TcpGROItem {
    key: TcpFlowKey,
    sent_seq: u32, bufs_index: u16, num_merged: u16, gso_size: u16, iph_len: u8, tcph_len: u8, psh_set: bool, }
227
228impl TcpGROTable {
234 fn reset(&mut self) {
235 for (_key, mut items) in self.items_by_flow.drain() {
236 items.clear();
237 self.items_pool.push(items);
238 }
239 }
240}
241
/// Hash key identifying a UDP flow eligible for coalescing.
/// Addresses are zero-padded to 16 bytes so IPv4 and IPv6 share one layout.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct UdpFlowKey {
    src_addr: [u8; 16], dst_addr: [u8; 16], src_port: u16, dst_port: u16, is_v6: bool, }
251
/// Per-batch table of UDP packets grouped by flow, awaiting coalescing.
pub struct UdpGROTable {
    // Flow -> packets seen for that flow in the current batch.
    items_by_flow: HashMap<UdpFlowKey, Vec<UdpGROItem>>,
    // Recycled item vectors, reused across batches to avoid reallocation.
    items_pool: Vec<Vec<UdpGROItem>>,
}
257
258impl Default for UdpGROTable {
259 fn default() -> Self {
260 UdpGROTable::new()
261 }
262}
263
264impl UdpGROTable {
265 pub fn new() -> Self {
266 let mut items_pool = Vec::with_capacity(IDEAL_BATCH_SIZE);
267 for _ in 0..IDEAL_BATCH_SIZE {
268 items_pool.push(Vec::with_capacity(IDEAL_BATCH_SIZE));
269 }
270 UdpGROTable {
271 items_by_flow: HashMap::with_capacity(IDEAL_BATCH_SIZE),
272 items_pool,
273 }
274 }
275}
276
277impl UdpFlowKey {
278 pub fn new(
279 pkt: &[u8],
280 src_addr_offset: usize,
281 dst_addr_offset: usize,
282 udph_offset: usize,
283 ) -> UdpFlowKey {
284 let mut key = UdpFlowKey {
285 src_addr: [0; 16],
286 dst_addr: [0; 16],
287 src_port: 0,
288 dst_port: 0,
289 is_v6: false,
290 };
291 let addr_size = dst_addr_offset - src_addr_offset;
292 key.src_addr[..addr_size].copy_from_slice(&pkt[src_addr_offset..dst_addr_offset]);
293 key.dst_addr[..addr_size]
294 .copy_from_slice(&pkt[dst_addr_offset..dst_addr_offset + addr_size]);
295 key.src_port = BigEndian::read_u16(&pkt[udph_offset..]);
296 key.dst_port = BigEndian::read_u16(&pkt[udph_offset + 2..]);
297 key.is_v6 = addr_size == 16;
298 key
299 }
300}
301
impl UdpGROTable {
    /// Looks up the flow for `pkt`. On a hit, returns the flow's item list so
    /// the caller can attempt coalescing. On a miss, records the packet as a
    /// new item (checksum state unknown) and returns `None`.
    fn lookup_or_insert(
        &mut self,
        pkt: &[u8],
        src_addr_offset: usize,
        dst_addr_offset: usize,
        udph_offset: usize,
        bufs_index: usize,
    ) -> Option<&mut Vec<UdpGROItem>> {
        let key = UdpFlowKey::new(pkt, src_addr_offset, dst_addr_offset, udph_offset);
        // Double lookup (contains_key + get_mut) sidesteps the NLL borrow
        // conflict a returned `get_mut` would have with `self.insert`.
        if self.items_by_flow.contains_key(&key) {
            self.items_by_flow.get_mut(&key)
        } else {
            self.insert(
                pkt,
                src_addr_offset,
                dst_addr_offset,
                udph_offset,
                bufs_index,
                false,
            );
            None
        }
    }
    /// Records `pkt` (living in `bufs[bufs_index]`) as a fresh GRO item.
    /// `c_sum_known_invalid` marks packets whose UDP checksum already failed
    /// validation, so they are never used as a coalescing base later.
    fn insert(
        &mut self,
        pkt: &[u8],
        src_addr_offset: usize,
        dst_addr_offset: usize,
        udph_offset: usize,
        bufs_index: usize,
        c_sum_known_invalid: bool,
    ) {
        let key = UdpFlowKey::new(pkt, src_addr_offset, dst_addr_offset, udph_offset);
        let item = UdpGROItem {
            key,
            bufs_index: bufs_index as u16,
            num_merged: 0,
            // UDP payload length of the first datagram.
            gso_size: (pkt.len() - (udph_offset + UDP_H_LEN)) as u16,
            // udph_offset equals the IP header length, since `pkt` starts at
            // the IP header.
            iph_len: udph_offset as u8,
            c_sum_known_invalid,
        };
        // Reuse a pooled Vec when available (disjoint closure capture keeps
        // this compatible with the `entry` borrow).
        let items = self
            .items_by_flow
            .entry(key)
            .or_insert_with(|| self.items_pool.pop().unwrap_or_default());
        items.push(item);
    }
}
/// State for one (possibly already-merged) UDP packet held in the table:
/// `bufs_index` addresses the owning buffer, `num_merged` counts merged
/// datagrams, `gso_size` is the datagram payload size, and
/// `c_sum_known_invalid` flags a packet whose checksum already failed.
#[derive(Debug, Clone, Copy)]
pub struct UdpGROItem {
    key: UdpFlowKey, bufs_index: u16, num_merged: u16, gso_size: u16, iph_len: u8, c_sum_known_invalid: bool, }
372impl UdpGROTable {
379 fn reset(&mut self) {
380 for (_key, mut items) in self.items_by_flow.drain() {
381 items.clear();
382 self.items_pool.push(items);
383 }
384 }
385}
386
/// Direction in which a candidate packet can merge with an existing item.
#[derive(Copy, Clone, Eq, PartialEq)]
enum CanCoalesce {
    // The new packet's data immediately precedes the item's (TCP only).
    Prepend,
    Unavailable,
    // The new packet's data immediately follows the item's.
    Append,
}
395
/// Returns `true` when the IP headers of two packets are similar enough for
/// their payloads to be coalesced into one superpacket.
///
/// For IPv6 this requires equal version/traffic-class bits (byte 0 and the
/// high nibble of byte 1) and equal hop limit (byte 7). For IPv4 it requires
/// an equal TOS byte, equal flag bits (top 3 bits of byte 6), and equal TTL
/// (byte 8). Packets shorter than 9 bytes never coalesce.
fn ip_headers_can_coalesce(pkt_a: &[u8], pkt_b: &[u8]) -> bool {
    if pkt_a.len() < 9 || pkt_b.len() < 9 {
        return false;
    }
    if pkt_a[0] >> 4 == 6 {
        pkt_a[0] == pkt_b[0] && pkt_a[1] >> 4 == pkt_b[1] >> 4 && pkt_a[7] == pkt_b[7]
    } else {
        pkt_a[1] == pkt_b[1] && pkt_a[6] >> 5 == pkt_b[6] >> 5 && pkt_a[8] == pkt_b[8]
    }
}
431
432fn udp_packets_can_coalesce<B: ExpandBuffer>(
437 pkt: &[u8],
438 iph_len: u8,
439 gso_size: u16,
440 item: &UdpGROItem,
441 bufs: &[B],
442 bufs_offset: usize,
443) -> CanCoalesce {
444 let pkt_target = &bufs[item.bufs_index as usize].as_ref()[bufs_offset..];
445 if !ip_headers_can_coalesce(pkt, pkt_target) {
446 return CanCoalesce::Unavailable;
447 }
448 if (pkt_target[(iph_len as usize + UDP_H_LEN)..].len()) % (item.gso_size as usize) != 0 {
449 return CanCoalesce::Unavailable;
452 }
453 if gso_size > item.gso_size {
454 return CanCoalesce::Unavailable;
456 }
457 CanCoalesce::Append
458}
459
/// Decides whether TCP packet `pkt` can merge with the existing GRO `item`,
/// and if so in which direction (based on sequence-number adjacency).
#[allow(clippy::too_many_arguments)]
fn tcp_packets_can_coalesce<B: ExpandBuffer>(
    pkt: &[u8],
    iph_len: u8,
    tcph_len: u8,
    seq: u32,
    psh_set: bool,
    gso_size: u16,
    item: &TcpGROItem,
    bufs: &[B],
    bufs_offset: usize,
) -> CanCoalesce {
    let pkt_target = &bufs[item.bufs_index as usize].as_ref()[bufs_offset..];

    // TCP header length (and therefore options length) must match exactly.
    if tcph_len != item.tcph_len {
        return CanCoalesce::Unavailable;
    }

    // When options are present (> 20 bytes), the option bytes must be identical.
    if tcph_len > 20
        && pkt[iph_len as usize + 20..iph_len as usize + tcph_len as usize]
            != pkt_target[item.iph_len as usize + 20..item.iph_len as usize + tcph_len as usize]
    {
        return CanCoalesce::Unavailable;
    }

    if !ip_headers_can_coalesce(pkt, pkt_target) {
        return CanCoalesce::Unavailable;
    }

    // Total payload currently held by the item: the original segment plus
    // every merged segment, each of `gso_size` bytes.
    let mut lhs_len = item.gso_size as usize;
    lhs_len += (item.num_merged as usize) * (item.gso_size as usize);

    if seq == item.sent_seq.wrapping_add(lhs_len as u32) {
        // `pkt` continues the item's data: append candidate.
        // A PSH on the item must stay the final segment of the superpacket.
        if item.psh_set {
            return CanCoalesce::Unavailable;
        }

        // Every segment preceding an append must be full-sized.
        if pkt_target[iph_len as usize + tcph_len as usize..].len() % item.gso_size as usize != 0 {
            return CanCoalesce::Unavailable;
        }

        if gso_size > item.gso_size {
            return CanCoalesce::Unavailable;
        }

        return CanCoalesce::Append;
    } else if seq.wrapping_add(gso_size as u32) == item.sent_seq {
        // `pkt` immediately precedes the item's data: prepend candidate.
        if psh_set {
            // A PSH-terminated packet cannot sit mid-superpacket.
            return CanCoalesce::Unavailable;
        }

        if gso_size < item.gso_size {
            // The head segment must be at least as large as those that follow.
            return CanCoalesce::Unavailable;
        }

        if gso_size > item.gso_size && item.num_merged > 0 {
            // A larger head is only acceptable while the item holds a single
            // segment.
            return CanCoalesce::Unavailable;
        }

        return CanCoalesce::Prepend;
    }

    CanCoalesce::Unavailable
}
542
543fn checksum_valid(pkt: &[u8], iph_len: u8, proto: u8, is_v6: bool) -> bool {
544 let (src_addr_at, addr_size) = if is_v6 {
545 (IPV6_SRC_ADDR_OFFSET, 16)
546 } else {
547 (IPV4_SRC_ADDR_OFFSET, 4)
548 };
549
550 let len_for_pseudo = (pkt.len() as u16).saturating_sub(iph_len as u16);
551
552 let c_sum = pseudo_header_checksum_no_fold(
553 proto,
554 &pkt[src_addr_at..src_addr_at + addr_size],
555 &pkt[src_addr_at + addr_size..src_addr_at + addr_size * 2],
556 len_for_pseudo,
557 );
558
559 !checksum(&pkt[iph_len as usize..], c_sum) == 0
560}
561
/// Outcome of a coalesce attempt.
enum CoalesceResult {
    // Destination buffer lacks capacity for the merged packet.
    InsufficientCap,
    // Refused because a PSH flag would end up mid-superpacket.
    PSHEnding,
    // The existing item's checksum failed validation.
    ItemInvalidCSum,
    // The incoming packet's checksum failed validation.
    PktInvalidCSum,
    Success,
}
571
/// Appends `pkt`'s UDP payload onto the GRO `item` held in `bufs`, validating
/// transport checksums first (the item's own checksum only before its first
/// merge). On success `item.num_merged` is bumped.
fn coalesce_udp_packets<B: ExpandBuffer>(
    pkt: &[u8],
    item: &mut UdpGROItem,
    bufs: &mut [B],
    bufs_offset: usize,
    is_v6: bool,
) -> CoalesceResult {
    let buf = bufs[item.bufs_index as usize].as_ref();
    let headers_len = item.iph_len as usize + UDP_H_LEN;
    let coalesced_len = buf[bufs_offset..].len() + pkt.len() - headers_len;
    // NOTE(review): the capacity check reserves `bufs_offset` bytes at both
    // the front and the tail — confirm this invariant with the buffer producer.
    if bufs[item.bufs_index as usize].buf_capacity() < bufs_offset * 2 + coalesced_len {
        return CoalesceResult::InsufficientCap;
    }

    // Validate the item's own checksum lazily, just before its first merge;
    // a checksum already known bad disqualifies it outright.
    if item.num_merged == 0
        && (item.c_sum_known_invalid
            || !checksum_valid(&buf[bufs_offset..], item.iph_len, IPPROTO_UDP as _, is_v6))
    {
        return CoalesceResult::ItemInvalidCSum;
    }

    if !checksum_valid(pkt, item.iph_len, IPPROTO_UDP as _, is_v6) {
        return CoalesceResult::PktInvalidCSum;
    }
    // Append payload only; the new packet's IP/UDP headers are discarded.
    bufs[item.bufs_index as usize].buf_extend_from_slice(&pkt[headers_len..]);
    item.num_merged += 1;
    CoalesceResult::Success
}
605
/// Merges TCP packet `pkt` into `item` in the given direction (`mode` must be
/// `Prepend` or `Append`). On success the coalesced superpacket lives in
/// `bufs[item.bufs_index]` and `item.num_merged` is bumped.
#[allow(clippy::too_many_arguments)]
fn coalesce_tcp_packets<B: ExpandBuffer>(
    mode: CanCoalesce,
    pkt: &[u8],
    pkt_bufs_index: usize,
    gso_size: u16,
    seq: u32,
    psh_set: bool,
    item: &mut TcpGROItem,
    bufs: &mut [B],
    bufs_offset: usize,
    is_v6: bool,
) -> CoalesceResult {
    let pkt_head: &[u8]; let headers_len = (item.iph_len + item.tcph_len) as usize;
    let coalesced_len =
        bufs[item.bufs_index as usize].as_ref()[bufs_offset..].len() + pkt.len() - headers_len;
    if mode == CanCoalesce::Prepend {
        // Prepend: build the merged packet in the NEW packet's buffer (it
        // already starts with the right headers/seq), then swap buffers.
        pkt_head = pkt;
        // NOTE(review): capacity check reserves `bufs_offset` at both ends —
        // confirm this invariant with the buffer producer.
        if bufs[pkt_bufs_index].buf_capacity() < 2 * bufs_offset + coalesced_len {
            return CoalesceResult::InsufficientCap;
        }
        if psh_set {
            // PSH must terminate a superpacket; it cannot precede more data.
            return CoalesceResult::PSHEnding;
        }
        // Validate the item's original checksum once, before its first merge.
        if item.num_merged == 0
            && !checksum_valid(
                &bufs[item.bufs_index as usize].as_ref()[bufs_offset..],
                item.iph_len,
                IPPROTO_TCP as _,
                is_v6,
            )
        {
            return CoalesceResult::ItemInvalidCSum;
        }
        if !checksum_valid(pkt, item.iph_len, IPPROTO_TCP as _, is_v6) {
            return CoalesceResult::PktInvalidCSum;
        }
        // The merged packet now starts at the new packet's sequence number.
        item.sent_seq = seq;
        let extend_by = coalesced_len - pkt_head.len();
        let len = bufs[pkt_bufs_index].as_ref().len();
        bufs[pkt_bufs_index].buf_resize(len + extend_by, 0);
        // Copy the item's payload (past its headers) in behind the new packet.
        // NOTE(review): copy_nonoverlapping assumes the two indices address
        // distinct buffers (item came from an earlier packet) — confirm the
        // caller can never pass pkt_bufs_index == item.bufs_index here.
        let src = bufs[item.bufs_index as usize].as_ref()[bufs_offset + headers_len..].as_ptr();
        let dst = bufs[pkt_bufs_index].as_mut()[bufs_offset + pkt.len()..].as_mut_ptr();
        unsafe {
            std::ptr::copy_nonoverlapping(src, dst, extend_by);
        }
        // Swap so `item.bufs_index` keeps addressing the coalesced packet.
        bufs.swap(item.bufs_index as usize, pkt_bufs_index);
    } else {
        // Append: grow the item's own buffer in place.
        if bufs[item.bufs_index as usize].buf_capacity() < 2 * bufs_offset + coalesced_len {
            return CoalesceResult::InsufficientCap;
        }
        // Validate the item's original checksum once, before its first merge.
        if item.num_merged == 0
            && !checksum_valid(
                &bufs[item.bufs_index as usize].as_ref()[bufs_offset..],
                item.iph_len,
                IPPROTO_TCP as _,
                is_v6,
            )
        {
            return CoalesceResult::ItemInvalidCSum;
        }
        if !checksum_valid(pkt, item.iph_len, IPPROTO_TCP as _, is_v6) {
            return CoalesceResult::PktInvalidCSum;
        }
        if psh_set {
            // Propagate PSH onto the merged packet's TCP flags byte.
            item.psh_set = psh_set;
            bufs[item.bufs_index as usize].as_mut()
                [bufs_offset + item.iph_len as usize + TCP_FLAGS_OFFSET] |= TCP_FLAG_PSH;
        }
        // Append only the new payload; its headers are dropped.
        bufs[item.bufs_index as usize].buf_extend_from_slice(&pkt[headers_len..]);
    }

    if gso_size > item.gso_size {
        item.gso_size = gso_size;
    }

    item.num_merged += 1;
    CoalesceResult::Success
}
703
// "More Fragments" bit within IPv4 header byte 6 (flags + frag-offset high bits).
const IPV4_FLAG_MORE_FRAGMENTS: u8 = 0x20;

// Byte offset of the source address within an IPv4 / IPv6 header.
const IPV4_SRC_ADDR_OFFSET: usize = 12;
const IPV6_SRC_ADDR_OFFSET: usize = 8;
/// Outcome of a GRO attempt for one packet.
#[derive(PartialEq, Eq)]
enum GroResult {
    // Not coalescable; write it out as-is.
    Noop,
    // Tracked in the table; header is finalized by the accounting pass.
    TableInsert,
    // Merged into an existing item's buffer; nothing new to write.
    Coalesced,
}
716
/// Attempts TCP GRO for the packet at `bufs[pkt_i][offset..]` (which must
/// start at the IP header). Validates the packet, then tries to coalesce it
/// with an existing item of the same flow; otherwise records it in `table`.
fn tcp_gro<B: ExpandBuffer>(
    bufs: &mut [B],
    offset: usize,
    pkt_i: usize,
    table: &mut TcpGROTable,
    is_v6: bool,
) -> GroResult {
    // Detach `pkt`'s lifetime from `bufs` so `bufs` can be passed mutably to
    // the coalescing helpers below while the packet bytes stay readable.
    // NOTE(review): this aliases later `&mut` borrows of the same slice; the
    // borrow checker does not verify the non-overlap of the actual writes.
    let pkt = unsafe { &*(&bufs[pkt_i].as_ref()[offset..] as *const [u8]) };
    if pkt.len() > u16::MAX as usize {
        // Length fields in the rewritten headers are 16-bit.
        return GroResult::Noop;
    }

    // IPv4: header length from IHL; IPv6: fixed 40-byte header, and the
    // payload-length field must agree with the slice length. For IPv4 the
    // total-length field must match instead.
    let mut iph_len = ((pkt[0] & 0x0F) * 4) as usize;
    if is_v6 {
        iph_len = 40;
        let ipv6_h_payload_len = u16::from_be_bytes([pkt[4], pkt[5]]) as usize;
        if ipv6_h_payload_len != pkt.len() - iph_len {
            return GroResult::Noop;
        }
    } else {
        let total_len = u16::from_be_bytes([pkt[2], pkt[3]]) as usize;
        if total_len != pkt.len() {
            return GroResult::Noop;
        }
    }

    if pkt.len() < iph_len {
        return GroResult::Noop;
    }

    // TCP data offset; must be a legal header size and fit in the packet.
    let tcph_len = ((pkt[iph_len + 12] >> 4) * 4) as usize;
    if !(20..=60).contains(&tcph_len) {
        return GroResult::Noop;
    }

    if pkt.len() < iph_len + tcph_len {
        return GroResult::Noop;
    }

    // Reject IPv4 fragments: MF set, or any nonzero fragment-offset bits.
    if !is_v6 && (pkt[6] & IPV4_FLAG_MORE_FRAGMENTS != 0 || pkt[6] << 3 != 0 || pkt[7] != 0) {
        return GroResult::Noop;
    }

    // Only pure ACK or ACK|PSH segments are coalescable; anything else
    // (SYN/FIN/RST/URG/ECN) passes through untouched.
    let tcp_flags = pkt[iph_len + TCP_FLAGS_OFFSET];
    let mut psh_set = false;

    if tcp_flags != TCP_FLAG_ACK {
        if pkt[iph_len + TCP_FLAGS_OFFSET] != TCP_FLAG_ACK | TCP_FLAG_PSH {
            return GroResult::Noop;
        }
        psh_set = true;
    }

    // Data-less ACKs are not coalesced.
    let gso_size = (pkt.len() - tcph_len - iph_len) as u16;
    if gso_size < 1 {
        return GroResult::Noop;
    }

    let seq = u32::from_be_bytes([
        pkt[iph_len + 4],
        pkt[iph_len + 5],
        pkt[iph_len + 6],
        pkt[iph_len + 7],
    ]);

    let mut src_addr_offset = IPV4_SRC_ADDR_OFFSET;
    let mut addr_len = 4;
    if is_v6 {
        src_addr_offset = IPV6_SRC_ADDR_OFFSET;
        addr_len = 16;
    }

    // First packet of the flow in this batch becomes a table insert.
    let items = if let Some(items) = table.lookup_or_insert(
        pkt,
        src_addr_offset,
        src_addr_offset + addr_len,
        iph_len,
        tcph_len,
        pkt_i,
    ) {
        items
    } else {
        return GroResult::TableInsert;
    };

    // Walk existing items newest-first: adjacency is most likely with the
    // most recently inserted packet.
    for i in (0..items.len()).rev() {
        let item = &mut items[i];
        let can = tcp_packets_can_coalesce(
            pkt,
            iph_len as u8,
            tcph_len as u8,
            seq,
            psh_set,
            gso_size,
            item,
            bufs,
            offset,
        );

        match can {
            CanCoalesce::Unavailable => {}
            _ => {
                let result = coalesce_tcp_packets(
                    can, pkt, pkt_i, gso_size, seq, psh_set, item, bufs, offset, is_v6,
                );

                match result {
                    CoalesceResult::Success => {
                        return GroResult::Coalesced;
                    }
                    CoalesceResult::ItemInvalidCSum => {
                        // Drop the corrupt item; the packet may still match
                        // an earlier one.
                        items.remove(i);
                    }
                    CoalesceResult::PktInvalidCSum => {
                        // Incoming packet is corrupt: hand it up unmodified.
                        return GroResult::Noop;
                    }
                    _ => {}
                }
            }
        }
    }

    // No merge possible: track the packet as an additional item for its flow.
    table.insert(
        pkt,
        src_addr_offset,
        src_addr_offset + addr_len,
        iph_len,
        tcph_len,
        pkt_i,
    );
    GroResult::TableInsert
}
871
872pub fn apply_tcp_coalesce_accounting<B: ExpandBuffer>(
875 bufs: &mut [B],
876 offset: usize,
877 table: &TcpGROTable,
878) -> io::Result<()> {
879 for items in table.items_by_flow.values() {
880 for item in items {
881 if item.num_merged > 0 {
882 let mut hdr = VirtioNetHdr {
883 flags: VIRTIO_NET_HDR_F_NEEDS_CSUM,
884 hdr_len: (item.iph_len + item.tcph_len) as u16,
885 gso_size: item.gso_size,
886 csum_start: item.iph_len as u16,
887 csum_offset: 16,
888 gso_type: 0, };
890 let buf = bufs[item.bufs_index as usize].as_mut();
891 let pkt = &mut buf[offset..];
892 let pkt_len = pkt.len();
893
894 let addr_len = if item.key.is_v6 { 16 } else { 4 };
898 let addr_offset = if item.key.is_v6 {
899 IPV6_SRC_ADDR_OFFSET
900 } else {
901 IPV4_SRC_ADDR_OFFSET
902 };
903
904 let src_addr_at = offset + addr_offset;
905 let src_addr =
906 unsafe { &*(&pkt[src_addr_at..src_addr_at + addr_len] as *const [u8]) };
907 let dst_addr = unsafe {
908 &*(&pkt[src_addr_at + addr_len..src_addr_at + addr_len * 2] as *const [u8])
909 };
910 if item.key.is_v6 {
913 hdr.gso_type = VIRTIO_NET_HDR_GSO_TCPV6;
914 BigEndian::write_u16(&mut pkt[4..6], pkt_len as u16 - item.iph_len as u16);
915 } else {
916 hdr.gso_type = VIRTIO_NET_HDR_GSO_TCPV4;
917 pkt[10] = 0;
918 pkt[11] = 0;
919 BigEndian::write_u16(&mut pkt[2..4], pkt_len as u16);
920 let iph_csum = !checksum(&pkt[..item.iph_len as usize], 0);
921 BigEndian::write_u16(&mut pkt[10..12], iph_csum);
922 }
923
924 hdr.encode(&mut buf[offset - VIRTIO_NET_HDR_LEN..])?;
925
926 let pkt = &mut buf[offset..];
927
928 let psum = pseudo_header_checksum_no_fold(
929 IPPROTO_TCP as _,
930 src_addr,
931 dst_addr,
932 pkt_len as u16 - item.iph_len as u16,
933 );
934 let tcp_csum = checksum(&[], psum);
935 BigEndian::write_u16(
936 &mut pkt[(hdr.csum_start + hdr.csum_offset) as usize..],
937 tcp_csum,
938 );
939 } else {
940 let hdr = VirtioNetHdr::default();
941 hdr.encode(
942 &mut bufs[item.bufs_index as usize].as_mut()[offset - VIRTIO_NET_HDR_LEN..],
943 )?;
944 }
945 }
946 }
947 Ok(())
948}
949
950pub fn apply_udp_coalesce_accounting<B: ExpandBuffer>(
953 bufs: &mut [B],
954 offset: usize,
955 table: &UdpGROTable,
956) -> io::Result<()> {
957 for items in table.items_by_flow.values() {
958 for item in items {
959 if item.num_merged > 0 {
960 let hdr = VirtioNetHdr {
961 flags: VIRTIO_NET_HDR_F_NEEDS_CSUM, hdr_len: item.iph_len as u16 + UDP_H_LEN as u16,
963 gso_size: item.gso_size,
964 csum_start: item.iph_len as u16,
965 csum_offset: 6,
966 gso_type: VIRTIO_NET_HDR_GSO_UDP_L4,
967 };
968
969 let buf = bufs[item.bufs_index as usize].as_mut();
970 let pkt = &mut buf[offset..];
971 let pkt_len = pkt.len();
972
973 let (addr_len, addr_offset) = if item.key.is_v6 {
977 (16, IPV6_SRC_ADDR_OFFSET)
978 } else {
979 (4, IPV4_SRC_ADDR_OFFSET)
980 };
981
982 let src_addr_at = offset + addr_offset;
983 let src_addr =
984 unsafe { &*(&pkt[src_addr_at..(src_addr_at + addr_len)] as *const [u8]) };
985 let dst_addr = unsafe {
986 &*(&pkt[(src_addr_at + addr_len)..(src_addr_at + addr_len * 2)]
987 as *const [u8])
988 };
989
990 if item.key.is_v6 {
993 BigEndian::write_u16(&mut pkt[4..6], pkt_len as u16 - item.iph_len as u16);
994 } else {
996 pkt[10] = 0;
997 pkt[11] = 0;
998 BigEndian::write_u16(&mut pkt[2..4], pkt_len as u16); let iph_csum = !checksum(&pkt[..item.iph_len as usize], 0);
1000 BigEndian::write_u16(&mut pkt[10..12], iph_csum); }
1002
1003 hdr.encode(&mut buf[offset - VIRTIO_NET_HDR_LEN..])?;
1004 let pkt = &mut buf[offset..];
1005 BigEndian::write_u16(
1007 &mut pkt[(item.iph_len as usize + 4)..(item.iph_len as usize + 6)],
1008 pkt_len as u16 - item.iph_len as u16,
1009 );
1010
1011 let psum = pseudo_header_checksum_no_fold(
1012 IPPROTO_UDP as _,
1013 src_addr,
1014 dst_addr,
1015 pkt_len as u16 - item.iph_len as u16,
1016 );
1017
1018 let udp_csum = checksum(&[], psum);
1019 BigEndian::write_u16(
1020 &mut pkt[(hdr.csum_start + hdr.csum_offset) as usize..],
1021 udp_csum,
1022 );
1023 } else {
1024 let hdr = VirtioNetHdr::default();
1025 hdr.encode(
1026 &mut bufs[item.bufs_index as usize].as_mut()[offset - VIRTIO_NET_HDR_LEN..],
1027 )?;
1028 }
1029 }
1030 }
1031 Ok(())
1032}
1033
/// Classification of a packet for GRO eligibility.
#[derive(PartialEq, Eq)]
pub enum GroCandidateType {
    NotGRO,
    Tcp4GRO,
    Tcp6GRO,
    Udp4GRO,
    Udp6GRO,
}

/// Cheap pre-filter deciding whether `b` is worth feeding into GRO.
/// Checks only version, header-size minimums, and the transport protocol;
/// UDP candidates additionally require `can_udp_gro`.
pub fn packet_is_gro_candidate(b: &[u8], can_udp_gro: bool) -> GroCandidateType {
    // 28 bytes is the smallest frame of interest (IPv4 header + UDP header).
    if b.len() < 28 {
        return GroCandidateType::NotGRO;
    }
    match b[0] >> 4 {
        4 => {
            // IHL must be exactly 5: IPv4 options disqualify the packet.
            if b[0] & 0x0F != 5 {
                return GroCandidateType::NotGRO;
            }
            match b[9] {
                // 40 = IPv4 header + minimal TCP header.
                6 if b.len() >= 40 => GroCandidateType::Tcp4GRO,
                17 if can_udp_gro => GroCandidateType::Udp4GRO,
                _ => GroCandidateType::NotGRO,
            }
        }
        6 => match b[6] {
            // 60 = IPv6 header + minimal TCP header; 48 = IPv6 + UDP header.
            6 if b.len() >= 60 => GroCandidateType::Tcp6GRO,
            17 if b.len() >= 48 && can_udp_gro => GroCandidateType::Udp6GRO,
            _ => GroCandidateType::NotGRO,
        },
        _ => GroCandidateType::NotGRO,
    }
}
1066
/// Size of a UDP header in bytes.
const UDP_H_LEN: usize = 8;
1068
/// Attempts UDP GRO for the packet at `bufs[pkt_i][offset..]` (which must
/// start at the IP header). Unlike TCP, only the flow's most recent item is
/// considered: UDP coalescing requires in-order, equal-sized datagrams.
fn udp_gro<B: ExpandBuffer>(
    bufs: &mut [B],
    offset: usize,
    pkt_i: usize,
    table: &mut UdpGROTable,
    is_v6: bool,
) -> GroResult {
    // Detach `pkt`'s lifetime from `bufs` so `bufs` can be passed mutably to
    // the coalescing helpers below while the packet bytes stay readable.
    // NOTE(review): this aliases later `&mut` borrows of the same slice.
    let pkt = unsafe { &*(&bufs[pkt_i].as_ref()[offset..] as *const [u8]) };
    if pkt.len() > u16::MAX as usize {
        // Length fields in the rewritten headers are 16-bit.
        return GroResult::Noop;
    }

    // IPv4: header length from IHL and total-length must match; IPv6: fixed
    // 40-byte header and the payload-length field must match.
    let mut iph_len = ((pkt[0] & 0x0F) * 4) as usize;
    if is_v6 {
        iph_len = 40;
        let ipv6_payload_len = u16::from_be_bytes([pkt[4], pkt[5]]) as usize;
        if ipv6_payload_len != pkt.len() - iph_len {
            return GroResult::Noop;
        }
    } else {
        let total_len = u16::from_be_bytes([pkt[2], pkt[3]]) as usize;
        if total_len != pkt.len() {
            return GroResult::Noop;
        }
    }

    if pkt.len() < iph_len || pkt.len() < iph_len + UDP_H_LEN {
        return GroResult::Noop;
    }

    // Reject IPv4 fragments: MF set, or any nonzero fragment-offset bits.
    if !is_v6 && (pkt[6] & IPV4_FLAG_MORE_FRAGMENTS != 0 || pkt[6] << 3 != 0 || pkt[7] != 0) {
        return GroResult::Noop;
    }

    // Empty datagrams are not coalesced.
    let gso_size = (pkt.len() - UDP_H_LEN - iph_len) as u16;
    if gso_size < 1 {
        return GroResult::Noop;
    }

    let (src_addr_offset, addr_len) = if is_v6 {
        (IPV6_SRC_ADDR_OFFSET, 16)
    } else {
        (IPV4_SRC_ADDR_OFFSET, 4)
    };

    let items = table.lookup_or_insert(
        pkt,
        src_addr_offset,
        src_addr_offset + addr_len,
        iph_len,
        pkt_i,
    );

    // First packet of the flow in this batch becomes a table insert.
    let items = if let Some(items) = items {
        items
    } else {
        return GroResult::TableInsert;
    };

    // Only the newest item is a merge candidate.
    let items_len = items.len();
    let item = &mut items[items_len - 1];
    let can = udp_packets_can_coalesce(pkt, iph_len as u8, gso_size, item, bufs, offset);
    let mut pkt_csum_known_invalid = false;

    if can == CanCoalesce::Append {
        match coalesce_udp_packets(pkt, item, bufs, offset, is_v6) {
            CoalesceResult::Success => {
                return GroResult::Coalesced;
            }
            CoalesceResult::ItemInvalidCSum => {
                // Existing item is corrupt; fall through and insert the new
                // packet as its own item.
            }
            CoalesceResult::PktInvalidCSum => {
                // Remember the failure so the new item is never used as a
                // coalescing base later.
                pkt_csum_known_invalid = true;
            }
            _ => {}
        }
    }
    // Re-borrow the packet normally; no aliasing needed past this point.
    let pkt = &bufs[pkt_i].as_ref()[offset..];
    table.insert(
        pkt,
        src_addr_offset,
        src_addr_offset + addr_len,
        iph_len,
        pkt_i,
        pkt_csum_known_invalid,
    );
    GroResult::TableInsert
}
1171
1172pub fn handle_gro<B: ExpandBuffer>(
1178 bufs: &mut [B],
1179 offset: usize,
1180 tcp_table: &mut TcpGROTable,
1181 udp_table: &mut UdpGROTable,
1182 can_udp_gro: bool,
1183 to_write: &mut Vec<usize>,
1184) -> io::Result<()> {
1185 let bufs_len = bufs.len();
1186 for i in 0..bufs_len {
1187 if offset < VIRTIO_NET_HDR_LEN || offset > bufs[i].as_ref().len() - 1 {
1188 return Err(io::Error::new(
1189 io::ErrorKind::InvalidInput,
1190 "invalid offset",
1191 ));
1192 }
1193
1194 let result = match packet_is_gro_candidate(&bufs[i].as_ref()[offset..], can_udp_gro) {
1195 GroCandidateType::Tcp4GRO => tcp_gro(bufs, offset, i, tcp_table, false),
1196 GroCandidateType::Tcp6GRO => tcp_gro(bufs, offset, i, tcp_table, true),
1197 GroCandidateType::Udp4GRO => udp_gro(bufs, offset, i, udp_table, false),
1198 GroCandidateType::Udp6GRO => udp_gro(bufs, offset, i, udp_table, true),
1199 GroCandidateType::NotGRO => GroResult::Noop,
1200 };
1201
1202 match result {
1203 GroResult::Noop => {
1204 let hdr = VirtioNetHdr::default();
1205 hdr.encode(&mut bufs[i].as_mut()[offset - VIRTIO_NET_HDR_LEN..offset])?;
1206 to_write.push(i);
1208 }
1209 GroResult::TableInsert => {
1210 to_write.push(i);
1211 }
1212 _ => {}
1213 }
1214 }
1215
1216 let err_tcp = apply_tcp_coalesce_accounting(bufs, offset, tcp_table);
1217 let err_udp = apply_udp_coalesce_accounting(bufs, offset, udp_table);
1218 err_tcp?;
1219 err_udp?;
1220 Ok(())
1221}
1222
/// Splits the GSO superpacket in `input` (described by virtio header `hdr`)
/// into individual packets written at `out_bufs[..][out_offset..]`, storing
/// each packet's total length in `sizes`. Returns the segment count.
///
/// `input` must start at the IP header and is mutated: the IPv4 header
/// checksum (if v4) and the transport checksum field are zeroed up front.
///
/// # Errors
/// `ErrTooManySegments` when `out_bufs` cannot hold all segments.
pub fn gso_split<B: AsRef<[u8]> + AsMut<[u8]>>(
    input: &mut [u8],
    hdr: VirtioNetHdr,
    out_bufs: &mut [B],
    sizes: &mut [usize],
    out_offset: usize,
    is_v6: bool,
) -> io::Result<usize> {
    // csum_start doubles as the IP header length (checksum begins at the
    // transport header).
    let iph_len = hdr.csum_start as usize;
    let (src_addr_offset, addr_len) = if is_v6 {
        (IPV6_SRC_ADDR_OFFSET, 16)
    } else {
        // Zero the IPv4 header checksum before headers are copied out.
        input[10] = 0;
        input[11] = 0; (IPV4_SRC_ADDR_OFFSET, 4)
    };

    // Zero the transport checksum field; fresh per-segment checksums are
    // computed below.
    let transport_csum_at = (hdr.csum_start + hdr.csum_offset) as usize;
    input[transport_csum_at] = 0;
    input[transport_csum_at + 1] = 0; let (first_tcp_seq_num, protocol) =
        if hdr.gso_type == VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gso_type == VIRTIO_NET_HDR_GSO_TCPV6 {
            (
                BigEndian::read_u32(&input[hdr.csum_start as usize + 4..]),
                IPPROTO_TCP,
            )
        } else {
            (0, IPPROTO_UDP)
        };

    let mut next_segment_data_at = hdr.hdr_len as usize;
    let mut i = 0;

    while next_segment_data_at < input.len() {
        if i == out_bufs.len() {
            return Err(io::Error::other("ErrTooManySegments"));
        }

        // Each segment carries up to gso_size payload bytes; the final one
        // may be short.
        let mut next_segment_end = next_segment_data_at + hdr.gso_size as usize;
        if next_segment_end > input.len() {
            next_segment_end = input.len();
        }
        let segment_data_len = next_segment_end - next_segment_data_at;
        let total_len = hdr.hdr_len as usize + segment_data_len;

        sizes[i] = total_len;
        let out = &mut out_bufs[i].as_mut()[out_offset..];

        // Copy the IP header, then patch the per-segment fields.
        out[..iph_len].copy_from_slice(&input[..iph_len]);

        if !is_v6 {
            if i > 0 {
                // Give each segment a distinct IPv4 identification.
                let id = BigEndian::read_u16(&out[4..]).wrapping_add(i as u16);
                BigEndian::write_u16(&mut out[4..6], id);
            }
            // Per-segment total length and recomputed header checksum.
            BigEndian::write_u16(&mut out[2..4], total_len as u16);
            let ipv4_csum = !checksum(&out[..iph_len], 0);
            BigEndian::write_u16(&mut out[10..12], ipv4_csum);
        } else {
            // IPv6 payload length excludes the fixed 40-byte header.
            BigEndian::write_u16(&mut out[4..6], (total_len - iph_len) as u16);
        }

        // Copy the transport header.
        out[hdr.csum_start as usize..hdr.hdr_len as usize]
            .copy_from_slice(&input[hdr.csum_start as usize..hdr.hdr_len as usize]);

        if protocol == IPPROTO_TCP {
            // Sequence number advances by gso_size per segment.
            let tcp_seq = first_tcp_seq_num.wrapping_add(hdr.gso_size as u32 * i as u32);
            BigEndian::write_u32(
                &mut out[(hdr.csum_start + 4) as usize..(hdr.csum_start + 8) as usize],
                tcp_seq,
            );
            if next_segment_end != input.len() {
                // FIN/PSH belong only on the final segment.
                out[hdr.csum_start as usize + TCP_FLAGS_OFFSET] &= !(TCP_FLAG_FIN | TCP_FLAG_PSH);
            }
        } else {
            // Per-segment UDP length field (header + payload).
            let udp_len = (segment_data_len + (hdr.hdr_len - hdr.csum_start) as usize) as u16;
            BigEndian::write_u16(
                &mut out[(hdr.csum_start + 4) as usize..(hdr.csum_start + 6) as usize],
                udp_len,
            );
        }

        // Copy this segment's payload.
        out[hdr.hdr_len as usize..total_len]
            .as_mut()
            .copy_from_slice(&input[next_segment_data_at..next_segment_end]);

        // Full transport checksum: pseudo-header + header + payload.
        let transport_header_len = (hdr.hdr_len - hdr.csum_start) as usize;
        let len_for_pseudo = (transport_header_len + segment_data_len) as u16;
        let transport_csum_no_fold = pseudo_header_checksum_no_fold(
            protocol as u8,
            &input[src_addr_offset..src_addr_offset + addr_len],
            &input[src_addr_offset + addr_len..src_addr_offset + 2 * addr_len],
            len_for_pseudo,
        );
        let transport_csum = !checksum(
            &out[hdr.csum_start as usize..total_len],
            transport_csum_no_fold,
        );
        BigEndian::write_u16(
            &mut out[transport_csum_at..transport_csum_at + 2],
            transport_csum,
        );

        next_segment_data_at += hdr.gso_size as usize;
        i += 1;
    }

    Ok(i)
}
1339
1340pub fn gso_none_checksum(in_buf: &mut [u8], csum_start: u16, csum_offset: u16) {
1341 let csum_at = (csum_start + csum_offset) as usize;
1342 let initial = BigEndian::read_u16(&in_buf[csum_at..]);
1345 in_buf[csum_at] = 0;
1346 in_buf[csum_at + 1] = 0;
1347 let computed_checksum = checksum(&in_buf[csum_start as usize..], initial as u64);
1348 BigEndian::write_u16(&mut in_buf[csum_at..], !computed_checksum);
1349}
1350
/// Bundles all per-batch GRO state: the buffer indices queued for writing
/// plus the TCP and UDP flow tables.
#[derive(Default)]
pub struct GROTable {
    pub(crate) to_write: Vec<usize>,
    pub(crate) tcp_gro_table: TcpGROTable,
    pub(crate) udp_gro_table: UdpGROTable,
}
1358
1359impl GROTable {
1360 pub fn new() -> GROTable {
1361 GROTable {
1362 to_write: Vec::with_capacity(IDEAL_BATCH_SIZE),
1363 tcp_gro_table: TcpGROTable::new(),
1364 udp_gro_table: UdpGROTable::new(),
1365 }
1366 }
1367 pub(crate) fn reset(&mut self) {
1368 self.to_write.clear();
1369 self.tcp_gro_table.reset();
1370 self.udp_gro_table.reset();
1371 }
1372}
1373
/// Growable byte buffer abstraction used by GRO to extend packets in place.
pub trait ExpandBuffer: AsRef<[u8]> + AsMut<[u8]> {
    /// Total capacity of the underlying storage (not the current length).
    fn buf_capacity(&self) -> usize;
    /// Resizes to `new_len`, filling any newly added bytes with `value`.
    fn buf_resize(&mut self, new_len: usize, value: u8);
    /// Appends `src` to the end of the buffer.
    fn buf_extend_from_slice(&mut self, src: &[u8]);
}
1379
// Straight delegation onto `BytesMut`'s own growable-buffer API.
impl ExpandBuffer for BytesMut {
    fn buf_capacity(&self) -> usize {
        self.capacity()
    }

    fn buf_resize(&mut self, new_len: usize, value: u8) {
        self.resize(new_len, value)
    }

    fn buf_extend_from_slice(&mut self, extend: &[u8]) {
        self.extend_from_slice(extend)
    }
}
1393
// Mutable-reference variant so callers can pass `&mut BytesMut` slices.
impl ExpandBuffer for &mut BytesMut {
    fn buf_capacity(&self) -> usize {
        self.capacity()
    }
    fn buf_resize(&mut self, new_len: usize, value: u8) {
        self.resize(new_len, value)
    }

    fn buf_extend_from_slice(&mut self, extend: &[u8]) {
        self.extend_from_slice(extend)
    }
}
// Straight delegation onto `Vec<u8>`'s own growable-buffer API.
impl ExpandBuffer for Vec<u8> {
    fn buf_capacity(&self) -> usize {
        self.capacity()
    }

    fn buf_resize(&mut self, new_len: usize, value: u8) {
        self.resize(new_len, value)
    }

    fn buf_extend_from_slice(&mut self, extend: &[u8]) {
        self.extend_from_slice(extend)
    }
}
// Mutable-reference variant so callers can pass `&mut Vec<u8>` slices.
impl ExpandBuffer for &mut Vec<u8> {
    fn buf_capacity(&self) -> usize {
        self.capacity()
    }

    fn buf_resize(&mut self, new_len: usize, value: u8) {
        self.resize(new_len, value)
    }

    fn buf_extend_from_slice(&mut self, extend: &[u8]) {
        self.extend_from_slice(extend)
    }
}