use crate::ip::IpNextLevelProtocol;
use crate::Packet;
use alloc::vec::Vec;
use xenet_macro::packet;
use xenet_macro_helper::types::*;
use crate::util;
use std::net::{Ipv4Addr, Ipv6Addr};
/// Length in bytes of the fixed UDP header (four 16-bit fields).
pub const UDP_HEADER_LEN: usize = 8;
/// Owned copy of the four UDP header fields, independent of any packet buffer.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct UdpHeader {
/// Source port.
pub source: u16be,
/// Destination port.
pub destination: u16be,
/// Length field exactly as read from the header (the tests in this file set it
/// to header + payload size in bytes). Not validated by this type.
pub length: u16be,
/// Checksum field exactly as read from the header. Not verified by this type.
pub checksum: u16be,
}
impl UdpHeader {
pub fn from_bytes(packet: &[u8]) -> Result<UdpHeader, String> {
if packet.len() < UDP_HEADER_LEN {
return Err("Packet is too small for UDP header".to_string());
}
match UdpPacket::new(packet) {
Some(udp_packet) => Ok(UdpHeader {
source: udp_packet.get_source(),
destination: udp_packet.get_destination(),
length: udp_packet.get_length(),
checksum: udp_packet.get_checksum(),
}),
None => Err("Failed to parse UDP packet".to_string()),
}
}
pub(crate) fn from_packet(udp_packet: &UdpPacket) -> UdpHeader {
UdpHeader {
source: udp_packet.get_source(),
destination: udp_packet.get_destination(),
length: udp_packet.get_length(),
checksum: udp_packet.get_checksum(),
}
}
}
/// Wire layout of a UDP datagram, expanded by the `#[packet]` attribute macro.
///
/// The `UdpPacket` / `MutableUdpPacket` views (with `get_*` / `set_*`
/// accessors) used elsewhere in this file are presumably generated from this
/// definition by the macro.
#[packet]
pub struct Udp {
// 16-bit big-endian header fields, in wire order.
pub source: u16be,
pub destination: u16be,
pub length: u16be,
pub checksum: u16be,
// Bytes following the header fields. NOTE(review): how the macro bounds the
// payload (e.g. whether it honours `length`) is not visible here — confirm.
#[payload]
pub payload: Vec<u8>,
}
/// Computes the checksum of the UDP `packet` carried over IPv4 from `source`
/// to `destination`.
///
/// Shorthand for [`ipv4_checksum_adv`] with an empty `extra_data` slice.
pub fn ipv4_checksum(packet: &UdpPacket, source: &Ipv4Addr, destination: &Ipv4Addr) -> u16be {
    let no_extra_data: &[u8] = &[];
    ipv4_checksum_adv(packet, no_extra_data, source, destination)
}
/// Computes the checksum of the UDP `packet` carried over IPv4 from `source`
/// to `destination`, including `extra_data` in the checksummed bytes.
///
/// The raw sum comes from `util::ipv4_checksum`, called with
/// `IpNextLevelProtocol::Udp`. A result of zero is returned as `0xFFFF`,
/// because a zero UDP checksum field means "no checksum" (RFC 768).
pub fn ipv4_checksum_adv(
    packet: &UdpPacket,
    extra_data: &[u8],
    source: &Ipv4Addr,
    destination: &Ipv4Addr,
) -> u16be {
    // 16-bit word index of the checksum field (bytes 6..8 of the header),
    // passed to the util helper — presumably so the field's current value is skipped.
    const CHECKSUM_WORD: usize = 3;
    let checksum = util::ipv4_checksum(
        packet.packet(),
        CHECKSUM_WORD,
        extra_data,
        source,
        destination,
        IpNextLevelProtocol::Udp,
    );
    // RFC 768: a computed checksum of zero is transmitted as all ones
    // (its one's-complement equivalent), since zero signals "no checksum".
    if checksum == 0 {
        0xFFFF
    } else {
        checksum
    }
}
#[test]
fn udp_header_ipv4_test() {
    use crate::ipv4::MutableIpv4Packet;

    let payload = b"test";
    let mut packet = [0u8; 20 + 8 + 4];
    let ipv4_source = Ipv4Addr::new(192, 168, 0, 1);
    let ipv4_destination = Ipv4Addr::new(192, 168, 0, 199);
    {
        let mut ip_header = MutableIpv4Packet::new(&mut packet[..]).unwrap();
        ip_header.set_next_level_protocol(IpNextLevelProtocol::Udp);
        ip_header.set_source(ipv4_source);
        ip_header.set_destination(ipv4_destination);
    }
    // Byte-string literal instead of `'t' as u8` casts (clippy::char_lit_as_u8).
    packet[20 + 8..].copy_from_slice(payload);
    {
        let mut udp_header = MutableUdpPacket::new(&mut packet[20..]).unwrap();
        udp_header.set_source(12345);
        assert_eq!(udp_header.get_source(), 12345);
        udp_header.set_destination(54321);
        assert_eq!(udp_header.get_destination(), 54321);
        udp_header.set_length(8 + 4);
        assert_eq!(udp_header.get_length(), 8 + 4);
        let checksum = ipv4_checksum(&udp_header.to_immutable(), &ipv4_source, &ipv4_destination);
        udp_header.set_checksum(checksum);
        assert_eq!(udp_header.get_checksum(), 0x9178);
    }
    // Expected big-endian header bytes: source 12345 (0x3039), destination
    // 54321 (0xd431), length 12 (0x000c), checksum 0x9178.
    let ref_packet = [0x30, 0x39, 0xd4, 0x31, 0x00, 0x0c, 0x91, 0x78];
    assert_eq!(&ref_packet[..], &packet[20..28]);
    // Writing the header fields must leave the payload untouched.
    assert_eq!(&packet[28..], &payload[..]);
}
/// Computes the checksum of the UDP `packet` carried over IPv6 from `source`
/// to `destination`.
///
/// Shorthand for [`ipv6_checksum_adv`] with an empty `extra_data` slice.
pub fn ipv6_checksum(packet: &UdpPacket, source: &Ipv6Addr, destination: &Ipv6Addr) -> u16be {
    let no_extra_data: &[u8] = &[];
    ipv6_checksum_adv(packet, no_extra_data, source, destination)
}
/// Computes the checksum of the UDP `packet` carried over IPv6 from `source`
/// to `destination`, including `extra_data` in the checksummed bytes.
///
/// The raw sum comes from `util::ipv6_checksum`, called with
/// `IpNextLevelProtocol::Udp`. A result of zero is returned as `0xFFFF`:
/// IPv6 receivers must discard UDP packets whose checksum field is zero
/// (RFC 8200 §8.1).
pub fn ipv6_checksum_adv(
    packet: &UdpPacket,
    extra_data: &[u8],
    source: &Ipv6Addr,
    destination: &Ipv6Addr,
) -> u16be {
    // 16-bit word index of the checksum field (bytes 6..8 of the header),
    // passed to the util helper — presumably so the field's current value is skipped.
    const CHECKSUM_WORD: usize = 3;
    let checksum = util::ipv6_checksum(
        packet.packet(),
        CHECKSUM_WORD,
        extra_data,
        source,
        destination,
        IpNextLevelProtocol::Udp,
    );
    // RFC 768 / RFC 8200 §8.1: a computed checksum of zero must be sent as
    // all ones, because a zero field is invalid for UDP over IPv6.
    if checksum == 0 {
        0xFFFF
    } else {
        checksum
    }
}
#[test]
fn udp_header_ipv6_test() {
    use crate::ipv6::MutableIpv6Packet;

    let payload = b"test";
    let mut packet = [0u8; 40 + 8 + 4];
    let ipv6_source = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1);
    let ipv6_destination = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1);
    {
        let mut ip_header = MutableIpv6Packet::new(&mut packet[..]).unwrap();
        ip_header.set_next_header(IpNextLevelProtocol::Udp);
        ip_header.set_source(ipv6_source);
        ip_header.set_destination(ipv6_destination);
    }
    // Byte-string literal instead of `'t' as u8` casts (clippy::char_lit_as_u8).
    packet[40 + 8..].copy_from_slice(payload);
    {
        let mut udp_header = MutableUdpPacket::new(&mut packet[40..]).unwrap();
        udp_header.set_source(12345);
        assert_eq!(udp_header.get_source(), 12345);
        udp_header.set_destination(54321);
        assert_eq!(udp_header.get_destination(), 54321);
        udp_header.set_length(8 + 4);
        assert_eq!(udp_header.get_length(), 8 + 4);
        let checksum = ipv6_checksum(&udp_header.to_immutable(), &ipv6_source, &ipv6_destination);
        udp_header.set_checksum(checksum);
        assert_eq!(udp_header.get_checksum(), 0x1390);
    }
    // Expected big-endian header bytes: source 12345 (0x3039), destination
    // 54321 (0xd431), length 12 (0x000c), checksum 0x1390.
    let ref_packet = [0x30, 0x39, 0xd4, 0x31, 0x00, 0x0c, 0x13, 0x90];
    assert_eq!(&ref_packet[..], &packet[40..48]);
    // Writing the header fields must leave the payload untouched.
    assert_eq!(&packet[48..], &payload[..]);
}