sparkles_parser/
packet_decoder.rs

1use std::{io, thread};
2use std::io::{BufRead, ErrorKind, Read};
3use std::net::{ToSocketAddrs, UdpSocket};
4use std::sync::atomic::Ordering;
5use std::time::{Duration, Instant};
6use enumset::EnumSet;
7use log::{debug, info, warn};
8use thiserror::Error;
9use sparkles_core::protocol::headers::{LocalPacketHeader, SparklesMachineInfo};
10use sparkles_core::protocol::packets::{PacketType, RequestPacketType};
11use sparkles_core::protocol::sender::PacketFlags;
12use crate::SHUTDOWN_SIGNAL;
13
14pub enum Packet {
15    MachineInfo(SparklesMachineInfo),
16    DataBytes(Vec<(LocalPacketHeader, Vec<u8>)>),
17    FailedPages(Vec<LocalPacketHeader>),
18    TimestampFreq(u64, u64),
19    GracefulShutdown,
20    ConnectionAccepted,
21    Hello,
22}
23
/// Byte counters maintained by the decoder while reading packets.
#[derive(Copy, Clone, Debug, Default)]
pub struct ProtocolCounters {
    /// Bytes spent on framing: type patterns, length fields, sequence numbers, flags.
    pub protocol_overhead: usize,
    /// Payload bytes that belonged to `DataBytes` packets.
    pub trace_buf: usize,
    /// Payload bytes that belonged to every other packet type.
    pub secondary_packets: usize,
}

impl ProtocolCounters {
    /// Returns the sum of all three counters.
    pub fn total_bytes(&self) -> usize {
        let Self {
            protocol_overhead,
            trace_buf,
            secondary_packets,
        } = *self;
        protocol_overhead + trace_buf + secondary_packets
    }
}
36
37pub enum PacketDecoder {
38    Stream{
39        stream: Box<dyn BufRead + Send>,
40        is_eof: bool,
41        counters: ProtocolCounters,
42    },
43    Socket{
44        socket: UdpSocket,
45        is_eof: bool,
46        counters: ProtocolCounters,
47        
48        last_seq_num: u16,
49
50        partial_packet_info: Option<UdpParserState>,
51        
52        read_buffer: Vec<u8>,
53        last_recv_time: Option<Instant>,
54    }
55}
56
/// Accumulates the chunks of one long (multi-datagram) UDP packet.
///
/// Chunk numbers arrive as `u8`; every time chunk number `0` shows up after
/// at least one chunk has been stored, the numbering is assumed to have
/// wrapped and subsequent chunks are offset by another 256.
pub struct UdpParserState {
    /// Chunk payloads keyed by their absolute (wrap-adjusted) chunk index.
    /// An ordered map keeps duplicate detection at O(log n) per chunk and
    /// makes assembling the packet a plain in-order walk.
    received_chunks: std::collections::BTreeMap<usize, Vec<u8>>,
    /// Offset added to incoming `u8` chunk numbers (a multiple of 256).
    starting_num: usize,
}

impl UdpParserState {
    /// Starts a new long packet with its first received chunk.
    pub fn new(chunk_num: u8, data: &[u8]) -> Self {
        let mut received_chunks = std::collections::BTreeMap::new();
        received_chunks.insert(usize::from(chunk_num), data.to_vec());
        UdpParserState {
            received_chunks,
            starting_num: 0,
        }
    }

    /// Stores another chunk.
    ///
    /// Returns `false` (and stores nothing) when a chunk with the same
    /// absolute index has already been received, `true` otherwise.
    pub fn push(&mut self, chunk_num: u8, chunk: Vec<u8>) -> bool {
        // Chunk number 0 after the first chunk means the u8 counter wrapped.
        if chunk_num == 0 && !self.received_chunks.is_empty() {
            self.starting_num += 256;
        }

        let absolute_num = usize::from(chunk_num) + self.starting_num;
        match self.received_chunks.entry(absolute_num) {
            std::collections::btree_map::Entry::Occupied(_) => false,
            std::collections::btree_map::Entry::Vacant(slot) => {
                slot.insert(chunk);
                true
            }
        }
    }

    /// Concatenates all chunks in index order.
    ///
    /// Returns `None` when the stored indices do not form the contiguous
    /// range `0..=max` (i.e. at least one chunk is missing), or when no chunk
    /// is stored at all. The `&mut self` receiver is kept for compatibility
    /// with existing callers.
    pub fn build(&mut self) -> Option<Vec<u8>> {
        let (&max_chunk_num, _) = self.received_chunks.last_key_value()?;
        // Keys are unique, so `len == max + 1` implies every index is present.
        if max_chunk_num + 1 != self.received_chunks.len() {
            return None;
        }

        let total_len = self.received_chunks.values().map(Vec::len).sum();
        let mut res = Vec::with_capacity(total_len);
        for chunk in self.received_chunks.values() {
            res.extend_from_slice(chunk);
        }
        Some(res)
    }
}
95
96
97pub type ReadResult<T> = Result<T, PacketReadError>;
98#[derive(Error, Debug)]
99pub enum PacketReadError {
100    #[error("IO error: {0}")]
101    IoError(#[from] std::io::Error),
102    #[error("Decode error: {0}")]
103    DecodeError(#[from] bincode::error::DecodeError),
104
105    #[error("End of file reached")]
106    Eof,
107    #[error("Packet length too big")]
108    LengthTooBig,
109
110
111    // Udp errors
112    #[error("Udp packet too short")]
113    UdpPacketTooShort,
114    #[error("Incorrect packet type pattern")]
115    IncorrectPattern,
116    #[error("The same seq num was received twice!")]
117    RepeatedSeqNum,
118    #[error("Incomplete long packet")]
119    IncompleteLongPacket,
120}
121impl PacketDecoder {
122    pub fn from_stream(stream: impl Read + Send + 'static) -> Self {
123        let stream = Box::new(io::BufReader::new(stream));
124        PacketDecoder::Stream{
125            stream,
126            is_eof: false,
127            counters: ProtocolCounters::default(),
128        }
129    }
130
131    pub fn from_socket(addr: impl ToSocketAddrs) -> Self {
132        let socket = UdpSocket::bind("0.0.0.0:0").unwrap();
133        socket.connect(addr).unwrap();
134        socket.set_read_timeout(Some(Duration::from_secs(1))).unwrap();
135
136        #[cfg(feature="self-tracing")]
137        let g = sparkles_macro::range_event_start!("Subscribing to events...");
138        loop {
139            socket.send(&RequestPacketType::Subscribe.pattern()).unwrap();
140
141            let mut buf = [0u8; 32];
142            match socket.recv(&mut buf) {
143                Ok(32) if buf == PacketType::ConnectionAccepted.pattern() => {
144                    break;
145                }
146                Ok(_) => {
147                    warn!("Incorrect packet received from server! Ignoring...");
148                    thread::sleep(Duration::from_millis(500));
149                }
150                Err(e) => {
151                    if e.kind() == io::ErrorKind::WouldBlock {
152                        continue;
153                    }
154                    if !matches!(e.kind(), io::ErrorKind::ConnectionReset | io::ErrorKind::ConnectionRefused) {
155                        warn!("Error receiving packet from server: {}", e);
156                    }
157                    thread::sleep(Duration::from_millis(500));
158                }
159            }
160            
161            if SHUTDOWN_SIGNAL.load(Ordering::Relaxed) {
162                break;
163            }
164        }
165
166        PacketDecoder::Socket{
167            socket,
168            is_eof: false,
169            counters: ProtocolCounters::default(),
170            partial_packet_info: None,
171            last_seq_num: 0,
172            read_buffer: vec![0; 1400],
173            last_recv_time: None,
174        }
175    }
176
177    pub fn read_packet(&mut self) -> ReadResult<Option<Packet>> {
178        match self {
179            PacketDecoder::Stream {
180                stream, is_eof,
181                counters
182            } => {
183                if *is_eof {
184                    return Err(PacketReadError::Eof);
185                }
186
187                // read packet header and length
188                let mut packet_type_buf = vec![0; 32];
189                stream.read_exact(&mut packet_type_buf)?;
190                counters.protocol_overhead += 32;
191                
192                let mut read_find_packet_start = || -> Result<PacketType, io::Error> {
193                    loop {
194                        if let Some(header) = PacketType::try_from_pattern(&packet_type_buf[..32]) {
195                            return Ok(header);
196                        }
197
198                        // Shift buffer by 1 byte
199                        packet_type_buf.remove(0);
200                        let mut new_byte = [0; 1];
201                        stream.read_exact(&mut new_byte)?;
202                        counters.protocol_overhead += 32;
203                        packet_type_buf.push(new_byte[0]);
204                    }
205                };
206
207                let packet_type = read_find_packet_start();
208                let Ok(packet_type) = packet_type else {
209                    return Err(packet_type.unwrap_err().into());
210                };
211                let mut length_buf = [0; 4];
212                stream.read_exact(&mut length_buf)?;
213                counters.protocol_overhead += 4;
214                let length = u32::from_be_bytes(length_buf);
215                if length < 100_000_000 {
216                    // Parse packet data
217                    let mut data = vec![0; length as usize];
218                    stream.read_exact(&mut data)?;
219
220                    let res = parse_packet_from_data(packet_type, &data)?;
221                    if matches!(res, Packet::GracefulShutdown) {
222                        *is_eof = true;
223                    }
224                    if matches!(res, Packet::DataBytes(_)) {
225                        counters.trace_buf += length as usize;
226                    }
227                    else {
228                        counters.secondary_packets += length as usize;
229                    }
230                    Ok(Some(res))
231                }
232                else {
233                    warn!("[PacketDecoder] Packet size too large! Ignoring...");
234                    Err(PacketReadError::LengthTooBig)
235                }
236            }
237            PacketDecoder::Socket{
238                socket, is_eof,
239                partial_packet_info,
240                last_seq_num,
241                counters,
242                read_buffer,
243                last_recv_time,
244            } => unsafe {
245                if *is_eof {
246                    return Err(PacketReadError::Eof);
247                }
248
249                #[cfg(feature="self-tracing")]
250                let g = sparkles_macro::range_event_start!("Recv packet");
251                let new_packet_sz = socket.recv(read_buffer).inspect_err(move |e| {
252                    if e.kind() == ErrorKind::WouldBlock {
253                        #[cfg(feature="self-tracing")]
254                        sparkles_macro::range_event_end!(g, "Timeout!");
255                    }
256                })?;
257                let recv_time = Instant::now();
258                let dur_since_last_packet = last_recv_time.map(|t| recv_time - t);
259                #[cfg(feature="self-tracing")]
260                let g = sparkles_macro::range_event_start!("Decode packet");
261                
262                let packet = &read_buffer[..new_packet_sz];
263                if new_packet_sz < 32 + 3 {
264                    warn!("[PacketDecoder] Udp packet too short! Ignoring...");
265                    return Err(PacketReadError::UdpPacketTooShort);
266                }
267                let packet_type = &packet[..32];
268                let Some(packet_type) = PacketType::try_from_pattern(packet_type) else {
269                    warn!("[PacketDecoder] Udp packet type not recognized! Ignoring...");
270                    return Err(PacketReadError::IncorrectPattern);
271                };
272                counters.protocol_overhead += 32;
273                let seq_num = u16::from_be_bytes(packet[32..34].try_into().unwrap());
274                counters.protocol_overhead += 2;
275
276
277                if dur_since_last_packet.is_none_or(|d| d > Duration::from_secs(10)) {
278                }
279                else if seq_num == *last_seq_num {
280                    warn!("[PacketDecoder] Udp packet sequence number repeated! Ignoring...");
281                    return Err(PacketReadError::RepeatedSeqNum)
282                }
283                *last_seq_num = seq_num;
284                *last_recv_time = Some(recv_time);
285
286                let seq_num_is_incremented = (*last_seq_num).wrapping_add(1) == seq_num;
287                let lost_packets = (*last_seq_num).wrapping_sub(seq_num).wrapping_sub(1);
288                if !seq_num_is_incremented && lost_packets < u16::MAX - 1_000 {
289                    warn!("[PacketDecoder] We lost {} packets!", seq_num.wrapping_sub(*last_seq_num) - 1);
290                }
291
292                // Handle received packet
293                let flags = EnumSet::from_repr_unchecked(packet[34]);
294                counters.protocol_overhead += 1;
295                let (res, data_len) = if flags.contains(PacketFlags::ShortPacket) {
296                    let data = &packet[35..];
297                    let data_len = data.len();
298
299                    if partial_packet_info.is_some() {
300                        warn!("[PacketDecoder] Resetting partial packet info!");
301                        *partial_packet_info = None;
302                    }
303                    (Some(parse_packet_from_data(packet_type, data)?), data_len)
304                }
305                else {
306                    let chunk_num = packet[35];
307                    let data = &packet[36..];
308                    let data_len = data.len();
309
310
311                    if let Some(udp_state) = partial_packet_info {
312                        // info!("Long packet: chunk {chunk_num}, data_size: {data_len}");
313                        if !udp_state.push(chunk_num, data.to_vec()) {
314                            warn!("[PacketDecoder] Duplicate or incomplete chunks for long packet! Ignoring...");
315                            return Err(PacketReadError::IncompleteLongPacket);
316                        }
317                        if flags.contains(PacketFlags::PacketEnd) {
318                            info!("It was last chunk, building packet...");
319                            #[cfg(feature="self-tracing")]
320                            let g = sparkles_macro::range_event_start!("Building long packet");
321                            #[cfg(feature="self-tracing")]
322                            let g = sparkles_macro::range_event_start!("Assemple packet");
323                            let long_packet_data = udp_state.build();
324                            #[cfg(feature="self-tracing")]
325                            drop(g);
326                            *partial_packet_info = None;
327                            if let Some(data) = long_packet_data {
328                                #[cfg(feature="self-tracing")]
329                                let g = sparkles_macro::range_event_start!("Parse trace packet");
330                                let res = parse_packet_from_data(packet_type, &data)?;
331                                #[cfg(feature="self-tracing")]
332                                drop(g);
333                                (Some(res), data_len)
334                            }
335                            else {
336                                warn!("Some chunks of long packet are missing! Skipping...");
337                                return Err(PacketReadError::IncompleteLongPacket)
338                            }
339                        }
340                        else {
341                            (None, data_len)
342                        }
343                    }
344                    else {
345                        if !flags.contains(PacketFlags::PacketStart) {
346                            warn!("Assertion failed! Udp packet chunk without start flag!");
347                        }
348                        *partial_packet_info = Some(UdpParserState::new(chunk_num, data));
349                        (None, data_len)
350                    }
351                };
352
353
354                if packet_type == PacketType::DataBytes {
355                    counters.trace_buf += data_len;
356                }
357                else {
358                    counters.secondary_packets += data_len;
359
360                    if packet_type == PacketType::GracefulShutdown {
361                        *is_eof = true;
362                    }
363                }
364                Ok(res)
365            }
366        }
367    }
368
369    pub fn is_eof(&self) -> bool {
370        match self {
371            PacketDecoder::Socket{is_eof, ..} |
372            PacketDecoder::Stream{is_eof, ..} => *is_eof
373        }
374    }
375    pub fn counters(&self) -> ProtocolCounters {
376        match self {
377            PacketDecoder::Socket{counters, ..} |
378            PacketDecoder::Stream{counters, ..} => *counters
379        }
380    }
381}
382
383fn parse_packet_from_data(packet_type: PacketType, data: &[u8]) -> ReadResult<Packet> {
384    match packet_type {
385        PacketType::GracefulShutdown => {
386            Ok(Packet::GracefulShutdown)
387        }
388        PacketType::MachineInfo => {
389            info!("Got MachineInfo packet!");
390            let (machine_info, sz) = bincode::decode_from_slice(data, bincode_config())?;
391            if sz != data.len() {
392                warn!("[PacketDecoder] Assertion failed! MachineInfo packet size mismatch!");
393            }
394            Ok(Packet::MachineInfo(machine_info))
395        }
396        PacketType::DataBytes => {
397            let mut res = Vec::new();
398            let mut cursor = 0;
399            loop {
400                if cursor + 8 >= data.len() {
401                    warn!("[PacketDecoder] DataBytes partial data received!");
402                    break;
403                }
404                let length = u64::from_be_bytes(data[cursor..cursor + 8].try_into().unwrap()) as usize;
405                cursor += 8;
406                debug!("local packet header len: {length}");
407
408                if cursor + length >= data.len() {
409                    warn!("[PacketDecoder] DataBytes partial data received!");
410                    break;
411                }
412                let (local_packet_header, sz) = bincode::decode_from_slice(&data[cursor..], bincode_config())?;
413                if sz != length {
414                    warn!("[PacketDecoder] Assertion failed! LocalPacketHeader packet size mismatch!");
415                }
416                cursor += length;
417                
418                debug!("LocalPacketHeader received!");
419
420                if cursor + 8 >= data.len() {
421                    warn!("[PacketDecoder] DataBytes partial data received!");
422                    break;
423                }
424                let buf_len = u64::from_be_bytes(data[cursor..cursor + 8].try_into().unwrap()) as usize;
425                cursor += 8;
426
427                debug!("Data len received: {buf_len}");
428
429                if cursor + buf_len > data.len() {
430                    warn!("[PacketDecoder] DataBytes partial data received!");
431                    break;
432                }
433                let buf = data[cursor..cursor + buf_len].to_vec();
434                res.push((local_packet_header, buf));
435                cursor += buf_len;
436                debug!("Got {}KB of data!", buf_len as f32 / 1024.0);
437                
438                if cursor == data.len() {
439                    break;
440                }
441            }
442            Ok(Packet::DataBytes(res))
443        }
444        PacketType::FailedPages => {
445            let (failed_pages, sz) = bincode::decode_from_slice(data, bincode_config())?;
446            if sz != data.len() {
447                warn!("[PacketDecoder] Assertion failed! FailedPages packet size mismatch!");
448            }
449            Ok(Packet::FailedPages(failed_pages))
450        }
451        PacketType::TimestampFreq => {
452            let freq = u64::from_be_bytes(data[..8].try_into().unwrap());
453            let cur_tm = u64::from_be_bytes(data[8..16].try_into().unwrap());
454            Ok(Packet::TimestampFreq(freq, cur_tm))
455        }
456        PacketType::ConnectionAccepted => {
457            Ok(Packet::ConnectionAccepted)
458        }
459    }
460}
461
462fn bincode_config() -> impl bincode::config::Config {
463    bincode::config::standard().with_limit::<100_000>()
464}