//! Linux virtio-net offload (GSO/GRO) support for TUN devices.
//!
//! Ported from wireguard-go:
//! <https://github.com/WireGuard/wireguard-go/blob/master/tun/offload_linux.go>
use crate::platform::linux::checksum::{checksum, pseudo_header_checksum_no_fold};
use byteorder::{BigEndian, ByteOrder};
use bytes::BytesMut;
use libc::{IPPROTO_TCP, IPPROTO_UDP};
use std::collections::HashMap;
use std::io;
8
/// Flag/GSO-type values for [`VirtioNetHdr`], matching the kernel UAPI:
/// https://github.com/torvalds/linux/blob/master/include/uapi/linux/virtio_net.h
// Not a GSO frame.
pub const VIRTIO_NET_HDR_GSO_NONE: u8 = 0;
// Checksum must be completed using csum_start/csum_offset.
pub const VIRTIO_NET_HDR_F_NEEDS_CSUM: u8 = 1;
// GSO frame, IPv4 TCP (TSO).
pub const VIRTIO_NET_HDR_GSO_TCPV4: u8 = 1;
// GSO frame, IPv6 TCP.
pub const VIRTIO_NET_HDR_GSO_TCPV6: u8 = 4;
// GSO frame, IPv4 & IPv6 UDP (USO).
pub const VIRTIO_NET_HDR_GSO_UDP_L4: u8 = 5;
15
/// <https://github.com/WireGuard/wireguard-go/blob/master/conn/conn.go#L19>
///
/// maximum number of packets handled per read and write
pub const IDEAL_BATCH_SIZE: usize = 128;

// Byte offset of the flags byte within a TCP header (low nibble of the
// data-offset byte plus the flag bits).
const TCP_FLAGS_OFFSET: usize = 13;

// Individual TCP flag bits within the flags byte.
const TCP_FLAG_FIN: u8 = 0x01;
const TCP_FLAG_PSH: u8 = 0x08;
const TCP_FLAG_ACK: u8 = 0x10;
26
///  virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
/// kernel symbol is virtio_net_hdr.
///
/// https://github.com/torvalds/linux/blob/master/include/uapi/linux/virtio_net.h
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct VirtioNetHdr {
    // #define VIRTIO_NET_HDR_F_NEEDS_CSUM	1	/* Use csum_start, csum_offset */
    // #define VIRTIO_NET_HDR_F_DATA_VALID	2	/* Csum is valid */
    // #define VIRTIO_NET_HDR_F_RSC_INFO	4	/* rsc info in csum_ fields */
    pub flags: u8,
    // #define VIRTIO_NET_HDR_GSO_NONE		0	/* Not a GSO frame */
    // #define VIRTIO_NET_HDR_GSO_TCPV4	1	/* GSO frame, IPv4 TCP (TSO) */
    // #define VIRTIO_NET_HDR_GSO_UDP		3	/* GSO frame, IPv4 UDP (UFO) */
    // #define VIRTIO_NET_HDR_GSO_TCPV6	4	/* GSO frame, IPv6 TCP */
    // #define VIRTIO_NET_HDR_GSO_UDP_L4	5	/* GSO frame, IPv4& IPv6 UDP (USO) */
    // #define VIRTIO_NET_HDR_GSO_ECN		0x80	/* TCP has ECN set */
    pub gso_type: u8,
    // Ethernet + IP + tcp/udp hdrs
    pub hdr_len: u16,
    // Bytes to append to hdr_len per frame
    pub gso_size: u16,
    // Checksum calculation
    pub csum_start: u16,
    pub csum_offset: u16,
}

impl VirtioNetHdr {
    /// Deserializes a header from the first [`VIRTIO_NET_HDR_LEN`] bytes of
    /// `buf`.
    ///
    /// Returns `InvalidInput` when `buf` is too short. Fields are read in
    /// native byte order, exactly as the kernel's raw-struct copy would.
    pub fn decode(buf: &[u8]) -> io::Result<VirtioNetHdr> {
        if buf.len() < VIRTIO_NET_HDR_LEN {
            return Err(io::Error::new(io::ErrorKind::InvalidInput, "too short"));
        }
        // Field-by-field native-endian decode. The `#[repr(C)]` layout is
        // (u8, u8, u16, u16, u16, u16) with no padding, so this is
        // byte-for-byte equivalent to the previous unsafe memcpy while
        // needing no `unsafe`.
        Ok(VirtioNetHdr {
            flags: buf[0],
            gso_type: buf[1],
            hdr_len: u16::from_ne_bytes([buf[2], buf[3]]),
            gso_size: u16::from_ne_bytes([buf[4], buf[5]]),
            csum_start: u16::from_ne_bytes([buf[6], buf[7]]),
            csum_offset: u16::from_ne_bytes([buf[8], buf[9]]),
        })
    }

    /// Serializes the header into the first [`VIRTIO_NET_HDR_LEN`] bytes of
    /// `buf`, in native byte order.
    ///
    /// Returns `InvalidInput` when `buf` is too short.
    pub fn encode(&self, buf: &mut [u8]) -> io::Result<()> {
        if buf.len() < VIRTIO_NET_HDR_LEN {
            return Err(io::Error::new(io::ErrorKind::InvalidInput, "too short"));
        }
        buf[0] = self.flags;
        buf[1] = self.gso_type;
        buf[2..4].copy_from_slice(&self.hdr_len.to_ne_bytes());
        buf[4..6].copy_from_slice(&self.gso_size.to_ne_bytes());
        buf[6..8].copy_from_slice(&self.csum_start.to_ne_bytes());
        buf[8..10].copy_from_slice(&self.csum_offset.to_ne_bytes());
        Ok(())
    }
}

// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
pub const VIRTIO_NET_HDR_LEN: usize = std::mem::size_of::<VirtioNetHdr>();
86
/// tcpFlowKey represents the key for a TCP flow.
///
/// Addresses are stored in fixed 16-byte arrays; IPv4 addresses occupy the
/// first 4 bytes with the remainder zeroed (see `TcpFlowKey::new`).
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct TcpFlowKey {
    src_addr: [u8; 16],
    dst_addr: [u8; 16],
    src_port: u16,
    dst_port: u16,
    rx_ack: u32, // varying ack values should not be coalesced. Treat them as separate flows.
    is_v6: bool,
}
97
98/// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO.
99pub struct TcpGROTable {
100    items_by_flow: HashMap<TcpFlowKey, Vec<TcpGROItem>>,
101    items_pool: Vec<Vec<TcpGROItem>>,
102}
103
104impl Default for TcpGROTable {
105    fn default() -> Self {
106        Self::new()
107    }
108}
109
110impl TcpGROTable {
111    fn new() -> Self {
112        let mut items_pool = Vec::with_capacity(IDEAL_BATCH_SIZE);
113        for _ in 0..IDEAL_BATCH_SIZE {
114            items_pool.push(Vec::with_capacity(IDEAL_BATCH_SIZE));
115        }
116        TcpGROTable {
117            items_by_flow: HashMap::with_capacity(IDEAL_BATCH_SIZE),
118            items_pool,
119        }
120    }
121}
122
123impl TcpFlowKey {
124    fn new(pkt: &[u8], src_addr_offset: usize, dst_addr_offset: usize, tcph_offset: usize) -> Self {
125        let mut key = TcpFlowKey {
126            src_addr: [0; 16],
127            dst_addr: [0; 16],
128            src_port: 0,
129            dst_port: 0,
130            rx_ack: 0,
131            is_v6: false,
132        };
133
134        let addr_size = dst_addr_offset - src_addr_offset;
135        key.src_addr[..addr_size].copy_from_slice(&pkt[src_addr_offset..dst_addr_offset]);
136        key.dst_addr[..addr_size]
137            .copy_from_slice(&pkt[dst_addr_offset..dst_addr_offset + addr_size]);
138        key.src_port = BigEndian::read_u16(&pkt[tcph_offset..]);
139        key.dst_port = BigEndian::read_u16(&pkt[tcph_offset + 2..]);
140        key.rx_ack = BigEndian::read_u32(&pkt[tcph_offset + 8..]);
141        key.is_v6 = addr_size == 16;
142        key
143    }
144}
145
impl TcpGROTable {
    /// lookupOrInsert looks up a flow for the provided packet and metadata,
    /// returning the packets found for the flow, or inserting a new one if none
    /// is found.
    ///
    /// Returns `Some(items)` for an existing flow; returns `None` after
    /// inserting a fresh entry for a previously-unseen flow.
    fn lookup_or_insert(
        &mut self,
        pkt: &[u8],
        src_addr_offset: usize,
        dst_addr_offset: usize,
        tcph_offset: usize,
        tcph_len: usize,
        bufs_index: usize,
    ) -> Option<&mut Vec<TcpGROItem>> {
        let key = TcpFlowKey::new(pkt, src_addr_offset, dst_addr_offset, tcph_offset);
        // The double lookup (contains_key + get_mut) sidesteps a borrow-checker
        // limitation: returning `get_mut` directly would keep `self` borrowed
        // across the `self.insert` call below.
        if self.items_by_flow.contains_key(&key) {
            return self.items_by_flow.get_mut(&key);
        }
        // Insert the new item into the table
        self.insert(
            pkt,
            src_addr_offset,
            dst_addr_offset,
            tcph_offset,
            tcph_len,
            bufs_index,
        );
        None
    }
    /// insert an item in the table for the provided packet and packet metadata.
    fn insert(
        &mut self,
        pkt: &[u8],
        src_addr_offset: usize,
        dst_addr_offset: usize,
        tcph_offset: usize,
        tcph_len: usize,
        bufs_index: usize,
    ) {
        let key = TcpFlowKey::new(pkt, src_addr_offset, dst_addr_offset, tcph_offset);
        let item = TcpGROItem {
            key,
            bufs_index: bufs_index as u16,
            num_merged: 0,
            // Payload bytes after the TCP header, i.e. the segment size.
            gso_size: pkt[tcph_offset + tcph_len..].len() as u16,
            // tcph_offset doubles as the IP header length for this packet.
            iph_len: tcph_offset as u8,
            tcph_len: tcph_len as u8,
            sent_seq: BigEndian::read_u32(&pkt[tcph_offset + 4..tcph_offset + 8]),
            psh_set: pkt[tcph_offset + TCP_FLAGS_OFFSET] & TCP_FLAG_PSH != 0,
        };

        // Reuse a pooled Vec when one is available to avoid a per-flow
        // allocation.
        let items = self
            .items_by_flow
            .entry(key)
            .or_insert_with(|| self.items_pool.pop().unwrap_or_default());
        items.push(item);
    }
}
203// func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
204// 	items, _ := t.itemsByFlow[item.key]
205// 	items[i] = item
206// }
207//
208// func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) {
209// 	items, _ := t.itemsByFlow[key]
210// 	items = append(items[:i], items[i+1:]...)
211// 	t.itemsByFlow[key] = items
212// }
213
/// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
/// of a GRO evaluation across a vector of packets.
///
/// `Copy` is derived so items can be freely read while the owning table is
/// being mutated.
#[derive(Debug, Clone, Copy)]
pub struct TcpGROItem {
    key: TcpFlowKey,
    sent_seq: u32,   // the sequence number
    bufs_index: u16, // the index into the original bufs slice
    num_merged: u16, // the number of packets merged into this item
    gso_size: u16,   // payload size
    iph_len: u8,     // ip header len
    tcph_len: u8,    // tcp header len
    psh_set: bool,   // psh flag is set
}
227
228// func (t *tcpGROTable) newItems() []tcpGROItem {
229// 	var items []tcpGROItem
230// 	items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
231// 	return items
232// }
233impl TcpGROTable {
234    fn reset(&mut self) {
235        for (_key, mut items) in self.items_by_flow.drain() {
236            items.clear();
237            self.items_pool.push(items);
238        }
239    }
240}
241
/// udpFlowKey represents the key for a UDP flow.
///
/// Addresses are stored in fixed 16-byte arrays; IPv4 addresses occupy the
/// first 4 bytes with the remainder zeroed (see `UdpFlowKey::new`).
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct UdpFlowKey {
    src_addr: [u8; 16], // srcAddr
    dst_addr: [u8; 16], // dstAddr
    src_port: u16,      // srcPort
    dst_port: u16,      // dstPort
    is_v6: bool,        // isV6
}
251
252///  udpGROTable holds flow and coalescing information for the purposes of UDP GRO.
253pub struct UdpGROTable {
254    items_by_flow: HashMap<UdpFlowKey, Vec<UdpGROItem>>,
255    items_pool: Vec<Vec<UdpGROItem>>,
256}
257
258impl Default for UdpGROTable {
259    fn default() -> Self {
260        UdpGROTable::new()
261    }
262}
263
264impl UdpGROTable {
265    pub fn new() -> Self {
266        let mut items_pool = Vec::with_capacity(IDEAL_BATCH_SIZE);
267        for _ in 0..IDEAL_BATCH_SIZE {
268            items_pool.push(Vec::with_capacity(IDEAL_BATCH_SIZE));
269        }
270        UdpGROTable {
271            items_by_flow: HashMap::with_capacity(IDEAL_BATCH_SIZE),
272            items_pool,
273        }
274    }
275}
276
277impl UdpFlowKey {
278    pub fn new(
279        pkt: &[u8],
280        src_addr_offset: usize,
281        dst_addr_offset: usize,
282        udph_offset: usize,
283    ) -> UdpFlowKey {
284        let mut key = UdpFlowKey {
285            src_addr: [0; 16],
286            dst_addr: [0; 16],
287            src_port: 0,
288            dst_port: 0,
289            is_v6: false,
290        };
291        let addr_size = dst_addr_offset - src_addr_offset;
292        key.src_addr[..addr_size].copy_from_slice(&pkt[src_addr_offset..dst_addr_offset]);
293        key.dst_addr[..addr_size]
294            .copy_from_slice(&pkt[dst_addr_offset..dst_addr_offset + addr_size]);
295        key.src_port = BigEndian::read_u16(&pkt[udph_offset..]);
296        key.dst_port = BigEndian::read_u16(&pkt[udph_offset + 2..]);
297        key.is_v6 = addr_size == 16;
298        key
299    }
300}
301
impl UdpGROTable {
    /// Looks up a flow for the provided packet and metadata.
    /// Returns a reference to the packets found for the flow and a boolean indicating if the flow already existed.
    /// If the flow is not found, inserts a new flow and returns `None` for the items.
    fn lookup_or_insert(
        &mut self,
        pkt: &[u8],
        src_addr_offset: usize,
        dst_addr_offset: usize,
        udph_offset: usize,
        bufs_index: usize,
    ) -> Option<&mut Vec<UdpGROItem>> {
        let key = UdpFlowKey::new(pkt, src_addr_offset, dst_addr_offset, udph_offset);
        // The double lookup (contains_key + get_mut) sidesteps a borrow-checker
        // limitation: returning `get_mut` directly would keep `self` borrowed
        // across the `self.insert` call below.
        if self.items_by_flow.contains_key(&key) {
            self.items_by_flow.get_mut(&key)
        } else {
            // If the flow does not exist, insert a new entry.
            self.insert(
                pkt,
                src_addr_offset,
                dst_addr_offset,
                udph_offset,
                bufs_index,
                false,
            );
            None
        }
    }
    /// Inserts an item in the table for the provided packet and its metadata.
    fn insert(
        &mut self,
        pkt: &[u8],
        src_addr_offset: usize,
        dst_addr_offset: usize,
        udph_offset: usize,
        bufs_index: usize,
        c_sum_known_invalid: bool,
    ) {
        let key = UdpFlowKey::new(pkt, src_addr_offset, dst_addr_offset, udph_offset);
        let item = UdpGROItem {
            key,
            bufs_index: bufs_index as u16,
            num_merged: 0,
            // Datagram payload size: bytes after the UDP header.
            gso_size: (pkt.len() - (udph_offset + UDP_H_LEN)) as u16,
            // udph_offset doubles as the IP header length for this packet.
            iph_len: udph_offset as u8,
            c_sum_known_invalid,
        };
        // Reuse a pooled Vec when one is available to avoid a per-flow
        // allocation.
        let items = self
            .items_by_flow
            .entry(key)
            .or_insert_with(|| self.items_pool.pop().unwrap_or_default());
        items.push(item);
    }
}
356// func (u *udpGROTable) updateAt(item udpGROItem, i int) {
357// 	items, _ := u.itemsByFlow[item.key]
358// 	items[i] = item
359// }
360
/// udpGROItem represents bookkeeping data for a UDP packet during the lifetime
/// of a GRO evaluation across a vector of packets.
///
/// `Copy` is derived so items can be freely read while the owning table is
/// being mutated.
#[derive(Debug, Clone, Copy)]
pub struct UdpGROItem {
    key: UdpFlowKey,           // udpFlowKey
    bufs_index: u16,           // the index into the original bufs slice
    num_merged: u16,           // the number of packets merged into this item
    gso_size: u16,             // payload size
    iph_len: u8,               // ip header len
    c_sum_known_invalid: bool, // UDP header checksum validity; a false value DOES NOT imply valid, just unknown.
}
372// func (u *udpGROTable) newItems() []udpGROItem {
373// 	var items []udpGROItem
374// 	items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1]
375// 	return items
376// }
377
378impl UdpGROTable {
379    fn reset(&mut self) {
380        for (_key, mut items) in self.items_by_flow.drain() {
381            items.clear();
382            self.items_pool.push(items);
383        }
384    }
385}
386
/// canCoalesce represents the outcome of checking if two TCP packets are
/// candidates for coalescing.
#[derive(Copy, Clone, Eq, PartialEq)]
enum CanCoalesce {
    // The candidate packet precedes the tracked item sequence-wise and can be
    // placed in front of it.
    Prepend,
    // The packets cannot be merged.
    Unavailable,
    // The candidate packet follows the tracked item and can be appended.
    Append,
}
395
/// ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB
/// meet all requirements to be merged as part of a GRO operation, otherwise it
/// returns false.
///
/// Both slices must carry at least 9 bytes of IP header; anything shorter is
/// rejected outright.
fn ip_headers_can_coalesce(pkt_a: &[u8], pkt_b: &[u8]) -> bool {
    if pkt_a.len() < 9 || pkt_b.len() < 9 {
        return false;
    }
    if pkt_a[0] >> 4 == 6 {
        // IPv6: version + upper Traffic-Class nibble (byte 0), lower
        // Traffic-Class nibble (high nibble of byte 1), and Hop Limit
        // (byte 7) must all match.
        pkt_a[0] == pkt_b[0] && pkt_a[1] >> 4 == pkt_b[1] >> 4 && pkt_a[7] == pkt_b[7]
    } else {
        // IPv4: ToS (byte 1), DF/reserved flag bits (top 3 bits of byte 6;
        // MF is checked further up the stack), and TTL (byte 8) must match.
        pkt_a[1] == pkt_b[1] && pkt_a[6] >> 5 == pkt_b[6] >> 5 && pkt_a[8] == pkt_b[8]
    }
}
431
/// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
/// described by item. iphLen and gsoSize describe pkt. bufs is the vector of
/// packets involved in the current GRO evaluation. bufsOffset is the offset at
/// which packet data begins within bufs.
///
/// UDP has no sequence numbers, so the only possible merge is an append.
fn udp_packets_can_coalesce<B: ExpandBuffer>(
    pkt: &[u8],
    iph_len: u8,
    gso_size: u16,
    item: &UdpGROItem,
    bufs: &[B],
    bufs_offset: usize,
) -> CanCoalesce {
    // The packet already tracked in the table for this flow.
    let pkt_target = &bufs[item.bufs_index as usize].as_ref()[bufs_offset..];
    if !ip_headers_can_coalesce(pkt, pkt_target) {
        return CanCoalesce::Unavailable;
    }
    // NOTE(review): divides by item.gso_size — assumes items are never
    // inserted for zero-length UDP payloads; confirm the caller guards this.
    if (pkt_target[(iph_len as usize + UDP_H_LEN)..].len()) % (item.gso_size as usize) != 0 {
        // A smaller than gsoSize packet has been appended previously.
        // Nothing can come after a smaller packet on the end.
        return CanCoalesce::Unavailable;
    }
    if gso_size > item.gso_size {
        // We cannot have a larger packet following a smaller one.
        return CanCoalesce::Unavailable;
    }
    CanCoalesce::Append
}
459
/// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
/// described by item. This function makes considerations that match the kernel's
/// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
///
/// `iph_len`/`tcph_len`/`seq`/`psh_set`/`gso_size` describe `pkt`; `item` is
/// the tracked packet it may merge with. All sequence arithmetic wraps, as
/// TCP sequence numbers do.
#[allow(clippy::too_many_arguments)]
fn tcp_packets_can_coalesce<B: ExpandBuffer>(
    pkt: &[u8],
    iph_len: u8,
    tcph_len: u8,
    seq: u32,
    psh_set: bool,
    gso_size: u16,
    item: &TcpGROItem,
    bufs: &[B],
    bufs_offset: usize,
) -> CanCoalesce {
    // The packet already tracked in the table for this flow.
    let pkt_target = &bufs[item.bufs_index as usize].as_ref()[bufs_offset..];

    if tcph_len != item.tcph_len {
        // cannot coalesce with unequal tcp options len
        return CanCoalesce::Unavailable;
    }

    // Options occupy bytes 20..tcph_len of the TCP header; they must match
    // byte-for-byte.
    if tcph_len > 20
        && pkt[iph_len as usize + 20..iph_len as usize + tcph_len as usize]
            != pkt_target[item.iph_len as usize + 20..item.iph_len as usize + tcph_len as usize]
    {
        // cannot coalesce with unequal tcp options
        return CanCoalesce::Unavailable;
    }

    if !ip_headers_can_coalesce(pkt, pkt_target) {
        return CanCoalesce::Unavailable;
    }

    // seq adjacency
    // Total payload currently held by item: the original segment plus each
    // merged segment (all sized item.gso_size).
    let mut lhs_len = item.gso_size as usize;
    lhs_len += (item.num_merged as usize) * (item.gso_size as usize);

    if seq == item.sent_seq.wrapping_add(lhs_len as u32) {
        // pkt aligns following item from a seq num perspective
        if item.psh_set {
            // We cannot append to a segment that has the PSH flag set, PSH
            // can only be set on the final segment in a reassembled group.
            return CanCoalesce::Unavailable;
        }

        if pkt_target[iph_len as usize + tcph_len as usize..].len() % item.gso_size as usize != 0 {
            // A smaller than gsoSize packet has been appended previously.
            // Nothing can come after a smaller packet on the end.
            return CanCoalesce::Unavailable;
        }

        if gso_size > item.gso_size {
            // We cannot have a larger packet following a smaller one.
            return CanCoalesce::Unavailable;
        }

        return CanCoalesce::Append;
    } else if seq.wrapping_add(gso_size as u32) == item.sent_seq {
        // pkt aligns in front of item from a seq num perspective
        if psh_set {
            // We cannot prepend with a segment that has the PSH flag set, PSH
            // can only be set on the final segment in a reassembled group.
            return CanCoalesce::Unavailable;
        }

        if gso_size < item.gso_size {
            // We cannot have a larger packet following a smaller one.
            return CanCoalesce::Unavailable;
        }

        if gso_size > item.gso_size && item.num_merged > 0 {
            // There's at least one previous merge, and we're larger than all
            // previous. This would put multiple smaller packets on the end.
            return CanCoalesce::Unavailable;
        }

        return CanCoalesce::Prepend;
    }

    CanCoalesce::Unavailable
}
542
543fn checksum_valid(pkt: &[u8], iph_len: u8, proto: u8, is_v6: bool) -> bool {
544    let (src_addr_at, addr_size) = if is_v6 {
545        (IPV6_SRC_ADDR_OFFSET, 16)
546    } else {
547        (IPV4_SRC_ADDR_OFFSET, 4)
548    };
549
550    let len_for_pseudo = (pkt.len() as u16).saturating_sub(iph_len as u16);
551
552    let c_sum = pseudo_header_checksum_no_fold(
553        proto,
554        &pkt[src_addr_at..src_addr_at + addr_size],
555        &pkt[src_addr_at + addr_size..src_addr_at + addr_size * 2],
556        len_for_pseudo,
557    );
558
559    !checksum(&pkt[iph_len as usize..], c_sum) == 0
560}
561
/// coalesceResult represents the result of attempting to coalesce two TCP
/// packets.
enum CoalesceResult {
    // The target buffer lacks spare capacity; coalescing is skipped rather
    // than reallocating.
    InsufficientCap,
    // Cannot merge in front of a segment whose PSH flag is set.
    PSHEnding,
    // The tracked item's checksum was invalid; the caller drops it.
    ItemInvalidCSum,
    // The candidate packet's checksum was invalid.
    PktInvalidCSum,
    // The packets were merged.
    Success,
}
571
/// coalesceUDPPackets attempts to coalesce pkt with the packet described by
/// item, and returns the outcome.
///
/// On success the payload of `pkt` (after its IP+UDP headers) is appended to
/// the buffer tracked by `item`, and `item.num_merged` is bumped.
fn coalesce_udp_packets<B: ExpandBuffer>(
    pkt: &[u8],
    item: &mut UdpGROItem,
    bufs: &mut [B],
    bufs_offset: usize,
    is_v6: bool,
) -> CoalesceResult {
    let buf = bufs[item.bufs_index as usize].as_ref();
    // let pkt_head = &buf[bufs_offset..]; // the packet that will end up at the front
    let headers_len = item.iph_len as usize + UDP_H_LEN;
    let coalesced_len = buf[bufs_offset..].len() + pkt.len() - headers_len;
    // `bufs_offset * 2` keeps extra headroom in front of the packet
    // (presumably for the virtio-net header written later — mirrors
    // wireguard-go; confirm against the write path).
    if bufs[item.bufs_index as usize].buf_capacity() < bufs_offset * 2 + coalesced_len {
        // We don't want to allocate a new underlying array if capacity is
        // too small.
        return CoalesceResult::InsufficientCap;
    }

    // Checksums are validated lazily: an item with prior merges has already
    // been verified.
    if item.num_merged == 0
        && (item.c_sum_known_invalid
            || !checksum_valid(&buf[bufs_offset..], item.iph_len, IPPROTO_UDP as _, is_v6))
    {
        return CoalesceResult::ItemInvalidCSum;
    }

    if !checksum_valid(pkt, item.iph_len, IPPROTO_UDP as _, is_v6) {
        return CoalesceResult::PktInvalidCSum;
    }
    bufs[item.bufs_index as usize].buf_extend_from_slice(&pkt[headers_len..]);
    item.num_merged += 1;
    CoalesceResult::Success
}
605
/// coalesceTCPPackets attempts to coalesce pkt with the packet described by
/// item, and returns the outcome. This function may swap bufs elements in the
/// event of a prepend as item's bufs index is already being tracked for writing
/// to a Device.
#[allow(clippy::too_many_arguments)]
fn coalesce_tcp_packets<B: ExpandBuffer>(
    mode: CanCoalesce,
    pkt: &[u8],
    pkt_bufs_index: usize,
    gso_size: u16,
    seq: u32,
    psh_set: bool,
    item: &mut TcpGROItem,
    bufs: &mut [B],
    bufs_offset: usize,
    is_v6: bool,
) -> CoalesceResult {
    let pkt_head: &[u8]; // the packet that will end up at the front
    let headers_len = (item.iph_len + item.tcph_len) as usize;
    let coalesced_len =
        bufs[item.bufs_index as usize].as_ref()[bufs_offset..].len() + pkt.len() - headers_len;
    // Copy data
    if mode == CanCoalesce::Prepend {
        pkt_head = pkt;
        // `2 * bufs_offset` keeps extra headroom in front of the packet
        // (presumably for the virtio-net header — mirrors wireguard-go).
        if bufs[pkt_bufs_index].buf_capacity() < 2 * bufs_offset + coalesced_len {
            // We don't want to allocate a new underlying array if capacity is
            // too small.
            return CoalesceResult::InsufficientCap;
        }
        if psh_set {
            return CoalesceResult::PSHEnding;
        }
        // Checksums are validated lazily: an item with prior merges has
        // already been verified.
        if item.num_merged == 0
            && !checksum_valid(
                &bufs[item.bufs_index as usize].as_ref()[bufs_offset..],
                item.iph_len,
                IPPROTO_TCP as _,
                is_v6,
            )
        {
            return CoalesceResult::ItemInvalidCSum;
        }
        if !checksum_valid(pkt, item.iph_len, IPPROTO_TCP as _, is_v6) {
            return CoalesceResult::PktInvalidCSum;
        }
        // The prepended packet's sequence number becomes the item's.
        item.sent_seq = seq;
        let extend_by = coalesced_len - pkt_head.len();
        let len = bufs[pkt_bufs_index].as_ref().len();
        bufs[pkt_bufs_index].buf_resize(len + extend_by, 0);
        // Append the item's payload after pkt inside pkt's buffer.
        // SAFETY/NOTE(review): raw pointers are needed because src and dst
        // live in two different elements of `bufs`; copy_nonoverlapping
        // additionally requires item.bufs_index != pkt_bufs_index — confirm
        // callers never coalesce a packet with its own table entry.
        let src = bufs[item.bufs_index as usize].as_ref()[bufs_offset + headers_len..].as_ptr();
        let dst = bufs[pkt_bufs_index].as_mut()[bufs_offset + pkt.len()..].as_mut_ptr();
        unsafe {
            std::ptr::copy_nonoverlapping(src, dst, extend_by);
        }
        // Flip the slice headers in bufs as part of prepend. The index of item
        // is already being tracked for writing.
        bufs.swap(item.bufs_index as usize, pkt_bufs_index);
    } else {
        // pkt_head = &bufs[item.bufs_index as usize][bufs_offset..];
        if bufs[item.bufs_index as usize].buf_capacity() < 2 * bufs_offset + coalesced_len {
            // We don't want to allocate a new underlying array if capacity is
            // too small.
            return CoalesceResult::InsufficientCap;
        }
        if item.num_merged == 0
            && !checksum_valid(
                &bufs[item.bufs_index as usize].as_ref()[bufs_offset..],
                item.iph_len,
                IPPROTO_TCP as _,
                is_v6,
            )
        {
            return CoalesceResult::ItemInvalidCSum;
        }
        if !checksum_valid(pkt, item.iph_len, IPPROTO_TCP as _, is_v6) {
            return CoalesceResult::PktInvalidCSum;
        }
        if psh_set {
            // We are appending a segment with PSH set.
            item.psh_set = psh_set;
            bufs[item.bufs_index as usize].as_mut()
                [bufs_offset + item.iph_len as usize + TCP_FLAGS_OFFSET] |= TCP_FLAG_PSH;
        }
        // https://github.com/WireGuard/wireguard-go/blob/12269c2761734b15625017d8565745096325392f/tun/offload_linux.go#L495
        // extendBy := len(pkt) - int(headersLen)
        // 		bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
        // 		copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
        bufs[item.bufs_index as usize].buf_extend_from_slice(&pkt[headers_len..]);
    }

    // Track the largest segment size seen for the merged super-packet.
    if gso_size > item.gso_size {
        item.gso_size = gso_size;
    }

    item.num_merged += 1;
    CoalesceResult::Success
}
703
// More Fragments bit within the IPv4 flags byte (header byte 6).
const IPV4_FLAG_MORE_FRAGMENTS: u8 = 0x20;

// Byte offset of the source address within an IPv4 header; the destination
// address follows immediately.
const IPV4_SRC_ADDR_OFFSET: usize = 12;
// Byte offset of the source address within an IPv6 header; the destination
// address follows immediately.
const IPV6_SRC_ADDR_OFFSET: usize = 8;
// maxUint16         = 1<<16 - 1
709
/// Outcome of evaluating one packet for GRO.
#[derive(PartialEq, Eq)]
enum GroResult {
    // No action was taken for the packet.
    Noop,
    // The packet was inserted into the flow table for later coalescing.
    TableInsert,
    // The packet was merged into a previously tracked packet.
    Coalesced,
}
716
/// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
/// existing packets tracked in table. It returns a groResultNoop when no
/// action was taken, groResultTableInsert when the evaluated packet was
/// inserted into table, and groResultCoalesced when the evaluated packet was
/// coalesced with another packet in table.
fn tcp_gro<B: ExpandBuffer>(
    bufs: &mut [B],
    offset: usize,
    pkt_i: usize,
    table: &mut TcpGROTable,
    is_v6: bool,
) -> GroResult {
    // NOTE(review): this detaches `pkt`'s lifetime from `bufs` so the slice
    // can still be read while `bufs` is later borrowed mutably. Coalescing
    // only mutates *other* buffers (and swaps element headers), but confirm
    // no path writes through bufs[pkt_i]'s storage while `pkt` is live.
    let pkt = unsafe { &*(&bufs[pkt_i].as_ref()[offset..] as *const [u8]) };
    if pkt.len() > u16::MAX as usize {
        // A valid IPv4 or IPv6 packet will never exceed this.
        return GroResult::Noop;
    }

    // IPv4: IHL field (low nibble of byte 0), in 32-bit words.
    // NOTE(review): indexing assumes at least a minimal IP header is present;
    // a truncated/empty `pkt` would panic here — confirm callers validate.
    let mut iph_len = ((pkt[0] & 0x0F) * 4) as usize;
    if is_v6 {
        iph_len = 40;
        let ipv6_h_payload_len = u16::from_be_bytes([pkt[4], pkt[5]]) as usize;
        if ipv6_h_payload_len != pkt.len() - iph_len {
            return GroResult::Noop;
        }
    } else {
        let total_len = u16::from_be_bytes([pkt[2], pkt[3]]) as usize;
        if total_len != pkt.len() {
            return GroResult::Noop;
        }
    }

    if pkt.len() < iph_len {
        return GroResult::Noop;
    }

    // TCP data-offset field (high nibble), in 32-bit words.
    // NOTE(review): pkt[iph_len + 12] is only guarded by the weaker
    // `pkt.len() < iph_len` check above; a packet shorter than iph_len + 13
    // would panic — confirm upstream length validation.
    let tcph_len = ((pkt[iph_len + 12] >> 4) * 4) as usize;
    if !(20..=60).contains(&tcph_len) {
        return GroResult::Noop;
    }

    if pkt.len() < iph_len + tcph_len {
        return GroResult::Noop;
    }

    if !is_v6 && (pkt[6] & IPV4_FLAG_MORE_FRAGMENTS != 0 || pkt[6] << 3 != 0 || pkt[7] != 0) {
        // no GRO support for fragmented segments for now
        return GroResult::Noop;
    }

    let tcp_flags = pkt[iph_len + TCP_FLAGS_OFFSET];
    let mut psh_set = false;

    // not a candidate if any non-ACK flags (except PSH+ACK) are set
    if tcp_flags != TCP_FLAG_ACK {
        if pkt[iph_len + TCP_FLAGS_OFFSET] != TCP_FLAG_ACK | TCP_FLAG_PSH {
            return GroResult::Noop;
        }
        psh_set = true;
    }

    let gso_size = (pkt.len() - tcph_len - iph_len) as u16;
    // not a candidate if payload len is 0
    if gso_size < 1 {
        return GroResult::Noop;
    }

    // TCP sequence number (network order).
    let seq = u32::from_be_bytes([
        pkt[iph_len + 4],
        pkt[iph_len + 5],
        pkt[iph_len + 6],
        pkt[iph_len + 7],
    ]);

    let mut src_addr_offset = IPV4_SRC_ADDR_OFFSET;
    let mut addr_len = 4;
    if is_v6 {
        src_addr_offset = IPV6_SRC_ADDR_OFFSET;
        addr_len = 16;
    }

    // `None` means the flow was unknown and a fresh entry was just inserted.
    let items = if let Some(items) = table.lookup_or_insert(
        pkt,
        src_addr_offset,
        src_addr_offset + addr_len,
        iph_len,
        tcph_len,
        pkt_i,
    ) {
        items
    } else {
        return GroResult::TableInsert;
    };

    for i in (0..items.len()).rev() {
        // In the best case of packets arriving in order iterating in reverse is
        // more efficient if there are multiple items for a given flow. This
        // also enables a natural table.delete_at() in the
        // coalesce_item_invalid_csum case without the need for index tracking.
        // This algorithm makes a best effort to coalesce in the event of
        // unordered packets, where pkt may land anywhere in items from a
        // sequence number perspective, however once an item is inserted into
        // the table it is never compared across other items later.
        let item = &mut items[i];
        let can = tcp_packets_can_coalesce(
            pkt,
            iph_len as u8,
            tcph_len as u8,
            seq,
            psh_set,
            gso_size,
            item,
            bufs,
            offset,
        );

        match can {
            CanCoalesce::Unavailable => {}
            _ => {
                let result = coalesce_tcp_packets(
                    can, pkt, pkt_i, gso_size, seq, psh_set, item, bufs, offset, is_v6,
                );

                match result {
                    CoalesceResult::Success => {
                        // table.update_at(item, i);
                        return GroResult::Coalesced;
                    }
                    CoalesceResult::ItemInvalidCSum => {
                        // delete the item with an invalid csum
                        // table.delete_at(item.key, i);
                        items.remove(i);
                    }
                    CoalesceResult::PktInvalidCSum => {
                        // no point in inserting an item that we can't coalesce
                        return GroResult::Noop;
                    }
                    // InsufficientCap / PSHEnding: keep looking at other items.
                    _ => {}
                }
            }
        }
    }

    // failed to coalesce with any other packets; store the item in the flow
    table.insert(
        pkt,
        src_addr_offset,
        src_addr_offset + addr_len,
        iph_len,
        tcph_len,
        pkt_i,
    );
    GroResult::TableInsert
}
871
/// Updates `bufs` to account for TCP coalescing, based on the metadata found
/// in `table` (Go: applyTCPCoalesceAccounting).
///
/// For every flow item that merged at least one packet (`num_merged > 0`) this
/// rewrites the IP length and (IPv4) header-checksum fields, encodes a
/// [`VirtioNetHdr`] describing the GSO state into the headroom before the
/// packet, and seeds the TCP checksum field with the folded pseudo-header
/// checksum so downstream checksum offload can finish the sum. Items that
/// merged nothing still receive a default (all-zero) virtio-net header.
pub fn apply_tcp_coalesce_accounting<B: ExpandBuffer>(
    bufs: &mut [B],
    offset: usize,
    table: &TcpGROTable,
) -> io::Result<()> {
    for items in table.items_by_flow.values() {
        for item in items {
            if item.num_merged > 0 {
                let mut hdr = VirtioNetHdr {
                    flags: VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
                    hdr_len: (item.iph_len + item.tcph_len) as u16,
                    gso_size: item.gso_size,
                    csum_start: item.iph_len as u16,
                    // TCP checksum field sits 16 bytes into the TCP header.
                    csum_offset: 16,
                    gso_type: 0, // Will be set later
                };
                let buf = bufs[item.bufs_index as usize].as_mut();
                let pkt = &mut buf[offset..];
                let pkt_len = pkt.len();

                // Calculate the pseudo header checksum and place it at the TCP
                // checksum offset. Downstream checksum offloading will combine
                // this with computation of the tcp header and payload checksum.
                let addr_len = if item.key.is_v6 { 16 } else { 4 };
                let src_addr_at = if item.key.is_v6 {
                    IPV6_SRC_ADDR_OFFSET
                } else {
                    IPV4_SRC_ADDR_OFFSET
                };

                // NOTE(review): these raw-pointer casts launder the borrow so
                // the address slices can still be read after `pkt`/`buf` are
                // re-borrowed mutably below. The address bytes themselves are
                // not mutated in between, but this aliasing pattern should be
                // confirmed sound (e.g. under Miri).
                let src_addr =
                    unsafe { &*(&pkt[src_addr_at..src_addr_at + addr_len] as *const [u8]) };
                let dst_addr = unsafe {
                    &*(&pkt[src_addr_at + addr_len..src_addr_at + addr_len * 2] as *const [u8])
                };
                // Recalculate the total len (IPv4) or payload len (IPv6).
                // Recalculate the (IPv4) header checksum.
                if item.key.is_v6 {
                    hdr.gso_type = VIRTIO_NET_HDR_GSO_TCPV6;
                    BigEndian::write_u16(&mut pkt[4..6], pkt_len as u16 - item.iph_len as u16);
                } else {
                    hdr.gso_type = VIRTIO_NET_HDR_GSO_TCPV4;
                    // Zero the IPv4 header checksum before recomputing it.
                    pkt[10] = 0;
                    pkt[11] = 0;
                    BigEndian::write_u16(&mut pkt[2..4], pkt_len as u16);
                    let iph_csum = !checksum(&pkt[..item.iph_len as usize], 0);
                    BigEndian::write_u16(&mut pkt[10..12], iph_csum);
                }

                // Write the virtio-net header into the headroom preceding the packet.
                hdr.encode(&mut buf[offset - VIRTIO_NET_HDR_LEN..])?;

                let pkt = &mut buf[offset..];

                let psum = pseudo_header_checksum_no_fold(
                    IPPROTO_TCP as _,
                    src_addr,
                    dst_addr,
                    pkt_len as u16 - item.iph_len as u16,
                );
                // Fold the pseudo-header sum and store it in the TCP checksum
                // field; hardware/kernel offload completes it from csum_start.
                let tcp_csum = checksum(&[], psum);
                BigEndian::write_u16(
                    &mut pkt[(hdr.csum_start + hdr.csum_offset) as usize..],
                    tcp_csum,
                );
            } else {
                // Not coalesced with anything: emit a default (no-offload)
                // virtio-net header.
                let hdr = VirtioNetHdr::default();
                hdr.encode(
                    &mut bufs[item.bufs_index as usize].as_mut()[offset - VIRTIO_NET_HDR_LEN..],
                )?;
            }
        }
    }
    Ok(())
}
948
/// Updates `bufs` to account for UDP coalescing, based on the metadata found
/// in `table` (Go: applyUDPCoalesceAccounting).
///
/// For every flow item that merged at least one packet this rewrites the IP
/// length and (IPv4) header-checksum fields, encodes a USO-style
/// [`VirtioNetHdr`] into the headroom before the packet, rewrites the UDP
/// length field, and seeds the UDP checksum field with the folded
/// pseudo-header checksum. Items that merged nothing receive a default
/// (all-zero) virtio-net header.
pub fn apply_udp_coalesce_accounting<B: ExpandBuffer>(
    bufs: &mut [B],
    offset: usize,
    table: &UdpGROTable,
) -> io::Result<()> {
    for items in table.items_by_flow.values() {
        for item in items {
            if item.num_merged > 0 {
                let hdr = VirtioNetHdr {
                    flags: VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
                    hdr_len: item.iph_len as u16 + UDP_H_LEN as u16,
                    gso_size: item.gso_size,
                    csum_start: item.iph_len as u16,
                    // UDP checksum field sits 6 bytes into the UDP header.
                    csum_offset: 6,
                    gso_type: VIRTIO_NET_HDR_GSO_UDP_L4,
                };

                let buf = bufs[item.bufs_index as usize].as_mut();
                let pkt = &mut buf[offset..];
                let pkt_len = pkt.len();

                // Calculate the pseudo header checksum and place it at the UDP
                // checksum offset. Downstream checksum offloading will combine
                // this with computation of the udp header and payload checksum.
                let (addr_len, src_addr_at) = if item.key.is_v6 {
                    (16, IPV6_SRC_ADDR_OFFSET)
                } else {
                    (4, IPV4_SRC_ADDR_OFFSET)
                };

                // NOTE(review): raw-pointer casts launder the borrow so the
                // address slices survive the later mutable re-borrows of the
                // buffer; the address bytes are not mutated in between, but
                // this aliasing pattern should be confirmed sound (e.g. Miri).
                let src_addr =
                    unsafe { &*(&pkt[src_addr_at..(src_addr_at + addr_len)] as *const [u8]) };
                let dst_addr = unsafe {
                    &*(&pkt[(src_addr_at + addr_len)..(src_addr_at + addr_len * 2)]
                        as *const [u8])
                };

                // Recalculate the total len (IPv4) or payload len (IPv6).
                // Recalculate the (IPv4) header checksum.
                if item.key.is_v6 {
                    BigEndian::write_u16(&mut pkt[4..6], pkt_len as u16 - item.iph_len as u16);
                    // set new IPv6 header payload len
                } else {
                    // Zero the IPv4 header checksum before recomputing it.
                    pkt[10] = 0;
                    pkt[11] = 0;
                    BigEndian::write_u16(&mut pkt[2..4], pkt_len as u16); // set new total length
                    let iph_csum = !checksum(&pkt[..item.iph_len as usize], 0);
                    BigEndian::write_u16(&mut pkt[10..12], iph_csum); // set IPv4 header checksum field
                }

                // Write the virtio-net header into the headroom preceding the packet.
                hdr.encode(&mut buf[offset - VIRTIO_NET_HDR_LEN..])?;
                let pkt = &mut buf[offset..];
                // Recalculate the UDP len field value
                BigEndian::write_u16(
                    &mut pkt[(item.iph_len as usize + 4)..(item.iph_len as usize + 6)],
                    pkt_len as u16 - item.iph_len as u16,
                );

                let psum = pseudo_header_checksum_no_fold(
                    IPPROTO_UDP as _,
                    src_addr,
                    dst_addr,
                    pkt_len as u16 - item.iph_len as u16,
                );

                // Fold the pseudo-header sum into the UDP checksum field;
                // downstream offload completes it from csum_start.
                let udp_csum = checksum(&[], psum);
                BigEndian::write_u16(
                    &mut pkt[(hdr.csum_start + hdr.csum_offset) as usize..],
                    udp_csum,
                );
            } else {
                // Not coalesced with anything: emit a default (no-offload)
                // virtio-net header.
                let hdr = VirtioNetHdr::default();
                hdr.encode(
                    &mut bufs[item.bufs_index as usize].as_mut()[offset - VIRTIO_NET_HDR_LEN..],
                )?;
            }
        }
    }
    Ok(())
}
1031
/// Classification of a raw packet for GRO processing, produced by
/// `packet_is_gro_candidate`.
// Derive Debug/Clone/Copy alongside the existing PartialEq/Eq: this is a
// plain fieldless public enum, so the extra derives are zero-cost and
// backward-compatible, and Debug makes the type usable in logs/tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GroCandidateType {
    /// Not eligible for GRO coalescing.
    NotGRO,
    /// IPv4 TCP segment eligible for TCP GRO.
    Tcp4GRO,
    /// IPv6 TCP segment eligible for TCP GRO.
    Tcp6GRO,
    /// IPv4 UDP datagram eligible for UDP GRO.
    Udp4GRO,
    /// IPv6 UDP datagram eligible for UDP GRO.
    Udp6GRO,
}
1040
1041pub fn packet_is_gro_candidate(b: &[u8], can_udp_gro: bool) -> GroCandidateType {
1042    if b.len() < 28 {
1043        return GroCandidateType::NotGRO;
1044    }
1045    if b[0] >> 4 == 4 {
1046        if b[0] & 0x0F != 5 {
1047            // IPv4 packets w/IP options do not coalesce
1048            return GroCandidateType::NotGRO;
1049        }
1050        match b[9] {
1051            6 if b.len() >= 40 => return GroCandidateType::Tcp4GRO,
1052            17 if can_udp_gro => return GroCandidateType::Udp4GRO,
1053            _ => {}
1054        }
1055    } else if b[0] >> 4 == 6 {
1056        match b[6] {
1057            6 if b.len() >= 60 => return GroCandidateType::Tcp6GRO,
1058            17 if b.len() >= 48 && can_udp_gro => return GroCandidateType::Udp6GRO,
1059            _ => {}
1060        }
1061    }
1062    GroCandidateType::NotGRO
1063}
1064
/// Length of a UDP header in bytes.
const UDP_H_LEN: usize = 8;
1066
/// udpGRO evaluates the UDP packet at `pkt_i` in `bufs` for coalescing with
/// existing packets tracked in `table`. It returns `GroResult::Noop` when no
/// action was taken, `GroResult::TableInsert` when the evaluated packet was
/// inserted into `table`, and `GroResult::Coalesced` when the evaluated packet
/// was coalesced with another packet in `table`.
fn udp_gro<B: ExpandBuffer>(
    bufs: &mut [B],
    offset: usize,
    pkt_i: usize,
    table: &mut UdpGROTable,
    is_v6: bool,
) -> GroResult {
    // SAFETY-NOTE(review): the raw-pointer cast detaches `pkt`'s lifetime from
    // `bufs` so `bufs` can be passed mutably to the coalesce helpers below
    // while `pkt` is still held. Confirm the helpers never mutate this
    // packet's bytes while `pkt` is live.
    let pkt = unsafe { &*(&bufs[pkt_i].as_ref()[offset..] as *const [u8]) };
    if pkt.len() > u16::MAX as usize {
        // A valid IPv4 or IPv6 packet will never exceed this.
        return GroResult::Noop;
    }

    // IPv4: header length from the IHL nibble; IPv6: fixed 40-byte header.
    let mut iph_len = ((pkt[0] & 0x0F) * 4) as usize;
    if is_v6 {
        iph_len = 40;
        let ipv6_payload_len = u16::from_be_bytes([pkt[4], pkt[5]]) as usize;
        // The IPv6 payload-length field must match the actual payload size.
        if ipv6_payload_len != pkt.len() - iph_len {
            return GroResult::Noop;
        }
    } else {
        // The IPv4 total-length field must match the actual packet size.
        let total_len = u16::from_be_bytes([pkt[2], pkt[3]]) as usize;
        if total_len != pkt.len() {
            return GroResult::Noop;
        }
    }

    // Must have room for the IP header plus a full UDP header.
    if pkt.len() < iph_len || pkt.len() < iph_len + UDP_H_LEN {
        return GroResult::Noop;
    }

    // Reject IPv4 packets with MF set or a nonzero fragment offset.
    if !is_v6 && (pkt[6] & IPV4_FLAG_MORE_FRAGMENTS != 0 || pkt[6] << 3 != 0 || pkt[7] != 0) {
        // No GRO support for fragmented segments for now.
        return GroResult::Noop;
    }

    // The UDP payload length becomes the GSO segment size for this flow.
    let gso_size = (pkt.len() - UDP_H_LEN - iph_len) as u16;
    if gso_size < 1 {
        return GroResult::Noop;
    }

    let (src_addr_offset, addr_len) = if is_v6 {
        (IPV6_SRC_ADDR_OFFSET, 16)
    } else {
        (IPV4_SRC_ADDR_OFFSET, 4)
    };

    // `None` means the flow was not present and the packet was inserted.
    let items = table.lookup_or_insert(
        pkt,
        src_addr_offset,
        src_addr_offset + addr_len,
        iph_len,
        pkt_i,
    );

    let items = if let Some(items) = items {
        items
    } else {
        return GroResult::TableInsert;
    };

    // Only check the last item to prevent reordering packets for a flow.
    let items_len = items.len();
    let item = &mut items[items_len - 1];
    let can = udp_packets_can_coalesce(pkt, iph_len as u8, gso_size, item, bufs, offset);
    let mut pkt_csum_known_invalid = false;

    // UDP GRO only ever appends (no out-of-order handling like TCP).
    if can == CanCoalesce::Append {
        match coalesce_udp_packets(pkt, item, bufs, offset, is_v6) {
            CoalesceResult::Success => {
                // `item` is a mutable reference into the table, so no explicit
                // update is needed here.
                // table.update_at(*item, items_len - 1);
                return GroResult::Coalesced;
            }
            CoalesceResult::ItemInvalidCSum => {
                // If the existing item has an invalid checksum, take no action.
                // A new item will be stored, and the existing item won't be revisited.
            }
            CoalesceResult::PktInvalidCSum => {
                // Insert a new item but mark it with invalid checksum to avoid repeat checks.
                pkt_csum_known_invalid = true;
            }
            _ => {}
        }
    }
    // Re-borrow the packet (the laundered `pkt` above must not be reused here).
    let pkt = &bufs[pkt_i].as_ref()[offset..];
    // Failed to coalesce; store the packet in the flow.
    table.insert(
        pkt,
        src_addr_offset,
        src_addr_offset + addr_len,
        iph_len,
        pkt_i,
        pkt_csum_known_invalid,
    );
    GroResult::TableInsert
}
1169
1170/// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
1171/// packets into toWrite. toWrite, tcpTable, and udpTable should initially be
1172/// empty (but non-nil), and are passed in to save allocs as the caller may reset
1173/// and recycle them across vectors of packets. canUDPGRO indicates if UDP GRO is
1174/// supported.
1175pub fn handle_gro<B: ExpandBuffer>(
1176    bufs: &mut [B],
1177    offset: usize,
1178    tcp_table: &mut TcpGROTable,
1179    udp_table: &mut UdpGROTable,
1180    can_udp_gro: bool,
1181    to_write: &mut Vec<usize>,
1182) -> io::Result<()> {
1183    let bufs_len = bufs.len();
1184    for i in 0..bufs_len {
1185        if offset < VIRTIO_NET_HDR_LEN || offset > bufs[i].as_ref().len() - 1 {
1186            return Err(io::Error::new(
1187                io::ErrorKind::InvalidInput,
1188                "invalid offset",
1189            ));
1190        }
1191
1192        let result = match packet_is_gro_candidate(&bufs[i].as_ref()[offset..], can_udp_gro) {
1193            GroCandidateType::Tcp4GRO => tcp_gro(bufs, offset, i, tcp_table, false),
1194            GroCandidateType::Tcp6GRO => tcp_gro(bufs, offset, i, tcp_table, true),
1195            GroCandidateType::Udp4GRO => udp_gro(bufs, offset, i, udp_table, false),
1196            GroCandidateType::Udp6GRO => udp_gro(bufs, offset, i, udp_table, true),
1197            GroCandidateType::NotGRO => GroResult::Noop,
1198        };
1199
1200        match result {
1201            GroResult::Noop => {
1202                let hdr = VirtioNetHdr::default();
1203                hdr.encode(&mut bufs[i].as_mut()[offset - VIRTIO_NET_HDR_LEN..offset])?;
1204                // Fallthrough intended
1205                to_write.push(i);
1206            }
1207            GroResult::TableInsert => {
1208                to_write.push(i);
1209            }
1210            _ => {}
1211        }
1212    }
1213
1214    let err_tcp = apply_tcp_coalesce_accounting(bufs, offset, tcp_table);
1215    let err_udp = apply_udp_coalesce_accounting(bufs, offset, udp_table);
1216    err_tcp?;
1217    err_udp?;
1218    Ok(())
1219}
1220
/// gsoSplit splits the GSO super-packet in `input` into individual packets in
/// `out_bufs` (each payload written at `out_offset`), writing the size of each
/// element into `sizes`. It returns the number of buffers populated, and/or an
/// error.
///
/// `hdr` is the virtio-net header describing the super-packet: `csum_start` is
/// the IP header length, `hdr_len` the combined IP + transport header length,
/// and `gso_size` the per-segment payload size.
pub fn gso_split<B: AsRef<[u8]> + AsMut<[u8]>>(
    input: &mut [u8],
    hdr: VirtioNetHdr,
    out_bufs: &mut [B],
    sizes: &mut [usize],
    out_offset: usize,
    is_v6: bool,
) -> io::Result<usize> {
    let iph_len = hdr.csum_start as usize;
    let (src_addr_offset, addr_len) = if is_v6 {
        (IPV6_SRC_ADDR_OFFSET, 16)
    } else {
        input[10] = 0;
        input[11] = 0; // clear IPv4 header checksum
        (IPV4_SRC_ADDR_OFFSET, 4)
    };

    let transport_csum_at = (hdr.csum_start + hdr.csum_offset) as usize;
    input[transport_csum_at] = 0;
    input[transport_csum_at + 1] = 0; // clear TCP/UDP checksum

    // TCP needs the first sequence number to renumber each segment; UDP does not.
    let (first_tcp_seq_num, protocol) =
        if hdr.gso_type == VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gso_type == VIRTIO_NET_HDR_GSO_TCPV6 {
            (
                BigEndian::read_u32(&input[hdr.csum_start as usize + 4..]),
                IPPROTO_TCP,
            )
        } else {
            (0, IPPROTO_UDP)
        };

    let src_addr_bytes = &input[src_addr_offset..src_addr_offset + addr_len];
    let dst_addr_bytes = &input[src_addr_offset + addr_len..src_addr_offset + 2 * addr_len];
    let transport_header_len = (hdr.hdr_len - hdr.csum_start) as usize;

    // Precompute the sizes/pseudo-header checksum shared by every full-sized
    // (non-last) segment; the last segment may be shorter and is handled below.
    let nonlast_segment_data_len = hdr.gso_size as usize;
    let nonlast_len_for_pseudo = (transport_header_len + nonlast_segment_data_len) as u16;
    let nonlast_total_len = hdr.hdr_len as usize + nonlast_segment_data_len;

    let nonlast_transport_csum_no_fold = pseudo_header_checksum_no_fold(
        protocol as u8,
        src_addr_bytes,
        dst_addr_bytes,
        nonlast_len_for_pseudo,
    );

    let mut next_segment_data_at = hdr.hdr_len as usize;
    let mut i = 0;

    while next_segment_data_at < input.len() {
        if i == out_bufs.len() {
            return Err(io::Error::other("ErrTooManySegments"));
        }

        let next_segment_end = next_segment_data_at + hdr.gso_size as usize;
        // Select full-segment vs. (shorter) last-segment parameters.
        let (next_segment_end, segment_data_len, total_len, transport_csum_no_fold) =
            if next_segment_end > input.len() {
                let last_segment_data_len = input.len() - next_segment_data_at;
                let last_len_for_pseudo = (transport_header_len + last_segment_data_len) as u16;

                let last_total_len = hdr.hdr_len as usize + last_segment_data_len;
                let last_transport_csum_no_fold = pseudo_header_checksum_no_fold(
                    protocol as u8,
                    src_addr_bytes,
                    dst_addr_bytes,
                    last_len_for_pseudo,
                );
                (
                    input.len(),
                    last_segment_data_len,
                    last_total_len,
                    last_transport_csum_no_fold,
                )
            } else {
                (
                    next_segment_end,
                    hdr.gso_size as usize,
                    nonlast_total_len,
                    nonlast_transport_csum_no_fold,
                )
            };

        sizes[i] = total_len;
        let out = &mut out_bufs[i].as_mut()[out_offset..];

        // Copy the IP header from the super-packet, then patch it per segment.
        out[..iph_len].copy_from_slice(&input[..iph_len]);

        if !is_v6 {
            // For IPv4 we are responsible for incrementing the ID field,
            // updating the total len field, and recalculating the header
            // checksum.
            if i > 0 {
                let id = BigEndian::read_u16(&out[4..]).wrapping_add(i as u16);
                BigEndian::write_u16(&mut out[4..6], id);
            }
            BigEndian::write_u16(&mut out[2..4], total_len as u16);
            let ipv4_csum = !checksum(&out[..iph_len], 0);
            BigEndian::write_u16(&mut out[10..12], ipv4_csum);
        } else {
            // For IPv6 we are responsible for updating the payload length field.
            // IPv6 extensions are not checksumed, but included in the payload length.
            const IPV6_FIXED_HDR_LEN: usize = 40;
            let payload_len = total_len - IPV6_FIXED_HDR_LEN;
            BigEndian::write_u16(&mut out[4..6], payload_len as u16);
        }

        // Copy the transport header, then patch it per segment.
        out[hdr.csum_start as usize..hdr.hdr_len as usize]
            .copy_from_slice(&input[hdr.csum_start as usize..hdr.hdr_len as usize]);

        if protocol == IPPROTO_TCP {
            // Advance the sequence number by one gso_size per segment.
            let tcp_seq = first_tcp_seq_num.wrapping_add(hdr.gso_size as u32 * i as u32);
            BigEndian::write_u32(
                &mut out[(hdr.csum_start + 4) as usize..(hdr.csum_start + 8) as usize],
                tcp_seq,
            );
            if next_segment_end != input.len() {
                // FIN/PSH only belong on the last segment.
                out[hdr.csum_start as usize + TCP_FLAGS_OFFSET] &= !(TCP_FLAG_FIN | TCP_FLAG_PSH);
            }
        } else {
            // Rewrite the UDP length field for this segment.
            let udp_len = (segment_data_len + (hdr.hdr_len - hdr.csum_start) as usize) as u16;
            BigEndian::write_u16(
                &mut out[(hdr.csum_start + 4) as usize..(hdr.csum_start + 6) as usize],
                udp_len,
            );
        }

        // Copy this segment's slice of the payload.
        out[hdr.hdr_len as usize..total_len]
            .as_mut()
            .copy_from_slice(&input[next_segment_data_at..next_segment_end]);

        // Full software transport checksum over header + payload, seeded with
        // the pseudo-header sum computed above.
        let transport_csum = !checksum(
            &out[hdr.csum_start as usize..total_len],
            transport_csum_no_fold,
        );
        BigEndian::write_u16(
            &mut out[transport_csum_at..transport_csum_at + 2],
            transport_csum,
        );

        next_segment_data_at += hdr.gso_size as usize;
        i += 1;
    }

    Ok(i)
}
1369
1370pub fn gso_none_checksum(in_buf: &mut [u8], csum_start: u16, csum_offset: u16) {
1371    let csum_at = (csum_start + csum_offset) as usize;
1372    // The initial value at the checksum offset should be summed with the
1373    // checksum we compute. This is typically the pseudo-header checksum.
1374    let initial = BigEndian::read_u16(&in_buf[csum_at..]);
1375    in_buf[csum_at] = 0;
1376    in_buf[csum_at + 1] = 0;
1377    let computed_checksum = checksum(&in_buf[csum_start as usize..], initial as u64);
1378    BigEndian::write_u16(&mut in_buf[csum_at..], !computed_checksum);
1379}
1380
/// Reusable scratch state that `send_multiple` uses to assist in writing data.
///
/// Bundles the per-batch flow tables and output index list so their
/// allocations can be recycled across calls.
#[derive(Default)]
pub struct GROTable {
    // Indices of the buffers that should actually be written out.
    pub(crate) to_write: Vec<usize>,
    // Flow table for TCP GRO candidates.
    pub(crate) tcp_gro_table: TcpGROTable,
    // Flow table for UDP GRO candidates.
    pub(crate) udp_gro_table: UdpGROTable,
}
1388
1389impl GROTable {
1390    pub fn new() -> GROTable {
1391        GROTable {
1392            to_write: Vec::with_capacity(IDEAL_BATCH_SIZE),
1393            tcp_gro_table: TcpGROTable::new(),
1394            udp_gro_table: UdpGROTable::new(),
1395        }
1396    }
1397    pub(crate) fn reset(&mut self) {
1398        self.to_write.clear();
1399        self.tcp_gro_table.reset();
1400        self.udp_gro_table.reset();
1401    }
1402}
1403
/// A growable byte buffer used as the destination for GRO coalescing, where
/// payloads from several packets are appended into a single buffer.
pub trait ExpandBuffer: AsRef<[u8]> + AsMut<[u8]> {
    /// Returns the total capacity currently reserved by the buffer.
    fn buf_capacity(&self) -> usize;
    /// Resizes the buffer to `new_len`, filling any newly added bytes with `value`.
    fn buf_resize(&mut self, new_len: usize, value: u8);
    /// Appends `src` to the end of the buffer, growing it as needed.
    fn buf_extend_from_slice(&mut self, src: &[u8]);
}
1409
1410impl ExpandBuffer for BytesMut {
1411    fn buf_capacity(&self) -> usize {
1412        self.capacity()
1413    }
1414
1415    fn buf_resize(&mut self, new_len: usize, value: u8) {
1416        self.resize(new_len, value)
1417    }
1418
1419    fn buf_extend_from_slice(&mut self, extend: &[u8]) {
1420        self.extend_from_slice(extend)
1421    }
1422}
1423
1424impl ExpandBuffer for &mut BytesMut {
1425    fn buf_capacity(&self) -> usize {
1426        self.capacity()
1427    }
1428    fn buf_resize(&mut self, new_len: usize, value: u8) {
1429        self.resize(new_len, value)
1430    }
1431
1432    fn buf_extend_from_slice(&mut self, extend: &[u8]) {
1433        self.extend_from_slice(extend)
1434    }
1435}
1436impl ExpandBuffer for Vec<u8> {
1437    fn buf_capacity(&self) -> usize {
1438        self.capacity()
1439    }
1440
1441    fn buf_resize(&mut self, new_len: usize, value: u8) {
1442        self.resize(new_len, value)
1443    }
1444
1445    fn buf_extend_from_slice(&mut self, extend: &[u8]) {
1446        self.extend_from_slice(extend)
1447    }
1448}
1449impl ExpandBuffer for &mut Vec<u8> {
1450    fn buf_capacity(&self) -> usize {
1451        self.capacity()
1452    }
1453
1454    fn buf_resize(&mut self, new_len: usize, value: u8) {
1455        self.resize(new_len, value)
1456    }
1457
1458    fn buf_extend_from_slice(&mut self, extend: &[u8]) {
1459        self.extend_from_slice(extend)
1460    }
1461}