1use std::{io, thread};
2use std::io::{BufRead, ErrorKind, Read};
3use std::net::{ToSocketAddrs, UdpSocket};
4use std::sync::atomic::Ordering;
5use std::time::{Duration, Instant};
6use enumset::EnumSet;
7use log::{debug, info, warn};
8use thiserror::Error;
9use sparkles_core::protocol::headers::{LocalPacketHeader, SparklesMachineInfo};
10use sparkles_core::protocol::packets::{PacketType, RequestPacketType};
11use sparkles_core::protocol::sender::PacketFlags;
12use crate::SHUTDOWN_SIGNAL;
13
14pub enum Packet {
15 MachineInfo(SparklesMachineInfo),
16 DataBytes(Vec<(LocalPacketHeader, Vec<u8>)>),
17 FailedPages(Vec<LocalPacketHeader>),
18 TimestampFreq(u64, u64),
19 GracefulShutdown,
20 ConnectionAccepted,
21 Hello,
22}
23
/// Running byte totals maintained by the decoder, split by what the bytes were used for.
#[derive(Copy, Clone, Debug, Default)]
pub struct ProtocolCounters {
    /// Bytes spent on framing: type patterns, length fields, sequence numbers, flags.
    pub protocol_overhead: usize,
    /// Payload bytes of trace-data packets.
    pub trace_buf: usize,
    /// Payload bytes of every other packet type.
    pub secondary_packets: usize,
}

impl ProtocolCounters {
    /// Returns the sum of all three counters.
    pub fn total_bytes(&self) -> usize {
        let Self {
            protocol_overhead,
            trace_buf,
            secondary_packets,
        } = *self;
        protocol_overhead + trace_buf + secondary_packets
    }
}
36
37pub enum PacketDecoder {
38 Stream{
39 stream: Box<dyn BufRead + Send>,
40 is_eof: bool,
41 counters: ProtocolCounters,
42 },
43 Socket{
44 socket: UdpSocket,
45 is_eof: bool,
46 counters: ProtocolCounters,
47
48 last_seq_num: u16,
49
50 partial_packet_info: Option<UdpParserState>,
51
52 read_buffer: Vec<u8>,
53 last_recv_time: Option<Instant>,
54 }
55}
56
/// Collects the chunks of one long UDP packet until it can be reassembled.
pub struct UdpParserState {
    /// `(absolute chunk number, chunk bytes)` in arrival order.
    received_chunks: Vec<(usize, Vec<u8>)>,
    /// Offset added to the 8-bit wire chunk number; grows by 256 on every wrap-around.
    starting_num: usize,
}

impl UdpParserState {
    /// Starts a reassembly with its first received chunk.
    pub fn new(chunk_num: u8, data: &[u8]) -> Self {
        let first_chunk = (usize::from(chunk_num), data.to_vec());
        Self {
            received_chunks: vec![first_chunk],
            starting_num: 0,
        }
    }

    /// Stores another chunk.
    ///
    /// A wire chunk number of `0` is taken as a wrap-around of the 8-bit counter.
    /// Returns `false`, leaving the chunk out, when that absolute chunk number
    /// is already stored.
    pub fn push(&mut self, chunk_num: u8, chunk: Vec<u8>) -> bool {
        let counter_wrapped = chunk_num == 0 && !self.received_chunks.is_empty();
        if counter_wrapped {
            self.starting_num += 256;
        }

        let absolute_num = usize::from(chunk_num) + self.starting_num;
        let already_stored = self
            .received_chunks
            .iter()
            .any(|(stored, _)| *stored == absolute_num);
        if !already_stored {
            self.received_chunks.push((absolute_num, chunk));
        }
        !already_stored
    }

    /// Concatenates all chunks in chunk-number order.
    ///
    /// Returns `None` when the stored chunk numbers do not cover `0..=max`.
    pub fn build(&mut self) -> Option<Vec<u8>> {
        let highest_num = self
            .received_chunks
            .iter()
            .map(|(num, _)| *num)
            .max()
            .unwrap();
        // Chunk numbers are unique, so a full set has exactly `highest_num + 1` entries.
        if self.received_chunks.len() != highest_num + 1 {
            return None;
        }

        self.received_chunks.sort_by_key(|(num, _)| *num);
        let assembled = self
            .received_chunks
            .iter()
            .flat_map(|(_, chunk)| chunk.iter().copied())
            .collect();
        Some(assembled)
    }
}
95
96
97pub type ReadResult<T> = Result<T, PacketReadError>;
98#[derive(Error, Debug)]
99pub enum PacketReadError {
100 #[error("IO error: {0}")]
101 IoError(#[from] std::io::Error),
102 #[error("Decode error: {0}")]
103 DecodeError(#[from] bincode::error::DecodeError),
104
105 #[error("End of file reached")]
106 Eof,
107 #[error("Packet length too big")]
108 LengthTooBig,
109
110
111 #[error("Udp packet too short")]
113 UdpPacketTooShort,
114 #[error("Incorrect packet type pattern")]
115 IncorrectPattern,
116 #[error("The same seq num was received twice!")]
117 RepeatedSeqNum,
118 #[error("Incomplete long packet")]
119 IncompleteLongPacket,
120}
121impl PacketDecoder {
122 pub fn from_stream(stream: impl Read + Send + 'static) -> Self {
123 let stream = Box::new(io::BufReader::new(stream));
124 PacketDecoder::Stream{
125 stream,
126 is_eof: false,
127 counters: ProtocolCounters::default(),
128 }
129 }
130
131 pub fn from_socket(addr: impl ToSocketAddrs) -> Self {
132 let socket = UdpSocket::bind("0.0.0.0:0").unwrap();
133 socket.connect(addr).unwrap();
134 socket.set_read_timeout(Some(Duration::from_secs(1))).unwrap();
135
136 #[cfg(feature="self-tracing")]
137 let g = sparkles_macro::range_event_start!("Subscribing to events...");
138 loop {
139 socket.send(&RequestPacketType::Subscribe.pattern()).unwrap();
140
141 let mut buf = [0u8; 32];
142 match socket.recv(&mut buf) {
143 Ok(32) if buf == PacketType::ConnectionAccepted.pattern() => {
144 break;
145 }
146 Ok(_) => {
147 warn!("Incorrect packet received from server! Ignoring...");
148 thread::sleep(Duration::from_millis(500));
149 }
150 Err(e) => {
151 if e.kind() == io::ErrorKind::WouldBlock {
152 continue;
153 }
154 if !matches!(e.kind(), io::ErrorKind::ConnectionReset | io::ErrorKind::ConnectionRefused) {
155 warn!("Error receiving packet from server: {}", e);
156 }
157 thread::sleep(Duration::from_millis(500));
158 }
159 }
160
161 if SHUTDOWN_SIGNAL.load(Ordering::Relaxed) {
162 break;
163 }
164 }
165
166 PacketDecoder::Socket{
167 socket,
168 is_eof: false,
169 counters: ProtocolCounters::default(),
170 partial_packet_info: None,
171 last_seq_num: 0,
172 read_buffer: vec![0; 1400],
173 last_recv_time: None,
174 }
175 }
176
177 pub fn read_packet(&mut self) -> ReadResult<Option<Packet>> {
178 match self {
179 PacketDecoder::Stream {
180 stream, is_eof,
181 counters
182 } => {
183 if *is_eof {
184 return Err(PacketReadError::Eof);
185 }
186
187 let mut packet_type_buf = vec![0; 32];
189 stream.read_exact(&mut packet_type_buf)?;
190 counters.protocol_overhead += 32;
191
192 let mut read_find_packet_start = || -> Result<PacketType, io::Error> {
193 loop {
194 if let Some(header) = PacketType::try_from_pattern(&packet_type_buf[..32]) {
195 return Ok(header);
196 }
197
198 packet_type_buf.remove(0);
200 let mut new_byte = [0; 1];
201 stream.read_exact(&mut new_byte)?;
202 counters.protocol_overhead += 32;
203 packet_type_buf.push(new_byte[0]);
204 }
205 };
206
207 let packet_type = read_find_packet_start();
208 let Ok(packet_type) = packet_type else {
209 return Err(packet_type.unwrap_err().into());
210 };
211 let mut length_buf = [0; 4];
212 stream.read_exact(&mut length_buf)?;
213 counters.protocol_overhead += 4;
214 let length = u32::from_be_bytes(length_buf);
215 if length < 100_000_000 {
216 let mut data = vec![0; length as usize];
218 stream.read_exact(&mut data)?;
219
220 let res = parse_packet_from_data(packet_type, &data)?;
221 if matches!(res, Packet::GracefulShutdown) {
222 *is_eof = true;
223 }
224 if matches!(res, Packet::DataBytes(_)) {
225 counters.trace_buf += length as usize;
226 }
227 else {
228 counters.secondary_packets += length as usize;
229 }
230 Ok(Some(res))
231 }
232 else {
233 warn!("[PacketDecoder] Packet size too large! Ignoring...");
234 Err(PacketReadError::LengthTooBig)
235 }
236 }
237 PacketDecoder::Socket{
238 socket, is_eof,
239 partial_packet_info,
240 last_seq_num,
241 counters,
242 read_buffer,
243 last_recv_time,
244 } => unsafe {
245 if *is_eof {
246 return Err(PacketReadError::Eof);
247 }
248
249 #[cfg(feature="self-tracing")]
250 let g = sparkles_macro::range_event_start!("Recv packet");
251 let new_packet_sz = socket.recv(read_buffer).inspect_err(move |e| {
252 if e.kind() == ErrorKind::WouldBlock {
253 #[cfg(feature="self-tracing")]
254 sparkles_macro::range_event_end!(g, "Timeout!");
255 }
256 })?;
257 let recv_time = Instant::now();
258 let dur_since_last_packet = last_recv_time.map(|t| recv_time - t);
259 #[cfg(feature="self-tracing")]
260 let g = sparkles_macro::range_event_start!("Decode packet");
261
262 let packet = &read_buffer[..new_packet_sz];
263 if new_packet_sz < 32 + 3 {
264 warn!("[PacketDecoder] Udp packet too short! Ignoring...");
265 return Err(PacketReadError::UdpPacketTooShort);
266 }
267 let packet_type = &packet[..32];
268 let Some(packet_type) = PacketType::try_from_pattern(packet_type) else {
269 warn!("[PacketDecoder] Udp packet type not recognized! Ignoring...");
270 return Err(PacketReadError::IncorrectPattern);
271 };
272 counters.protocol_overhead += 32;
273 let seq_num = u16::from_be_bytes(packet[32..34].try_into().unwrap());
274 counters.protocol_overhead += 2;
275
276
277 if dur_since_last_packet.is_none_or(|d| d > Duration::from_secs(10)) {
278 }
279 else if seq_num == *last_seq_num {
280 warn!("[PacketDecoder] Udp packet sequence number repeated! Ignoring...");
281 return Err(PacketReadError::RepeatedSeqNum)
282 }
283 *last_seq_num = seq_num;
284 *last_recv_time = Some(recv_time);
285
286 let seq_num_is_incremented = (*last_seq_num).wrapping_add(1) == seq_num;
287 let lost_packets = (*last_seq_num).wrapping_sub(seq_num).wrapping_sub(1);
288 if !seq_num_is_incremented && lost_packets < u16::MAX - 1_000 {
289 warn!("[PacketDecoder] We lost {} packets!", seq_num.wrapping_sub(*last_seq_num) - 1);
290 }
291
292 let flags = EnumSet::from_repr_unchecked(packet[34]);
294 counters.protocol_overhead += 1;
295 let (res, data_len) = if flags.contains(PacketFlags::ShortPacket) {
296 let data = &packet[35..];
297 let data_len = data.len();
298
299 if partial_packet_info.is_some() {
300 warn!("[PacketDecoder] Resetting partial packet info!");
301 *partial_packet_info = None;
302 }
303 (Some(parse_packet_from_data(packet_type, data)?), data_len)
304 }
305 else {
306 let chunk_num = packet[35];
307 let data = &packet[36..];
308 let data_len = data.len();
309
310
311 if let Some(udp_state) = partial_packet_info {
312 if !udp_state.push(chunk_num, data.to_vec()) {
314 warn!("[PacketDecoder] Duplicate or incomplete chunks for long packet! Ignoring...");
315 return Err(PacketReadError::IncompleteLongPacket);
316 }
317 if flags.contains(PacketFlags::PacketEnd) {
318 info!("It was last chunk, building packet...");
319 #[cfg(feature="self-tracing")]
320 let g = sparkles_macro::range_event_start!("Building long packet");
321 #[cfg(feature="self-tracing")]
322 let g = sparkles_macro::range_event_start!("Assemple packet");
323 let long_packet_data = udp_state.build();
324 #[cfg(feature="self-tracing")]
325 drop(g);
326 *partial_packet_info = None;
327 if let Some(data) = long_packet_data {
328 #[cfg(feature="self-tracing")]
329 let g = sparkles_macro::range_event_start!("Parse trace packet");
330 let res = parse_packet_from_data(packet_type, &data)?;
331 #[cfg(feature="self-tracing")]
332 drop(g);
333 (Some(res), data_len)
334 }
335 else {
336 warn!("Some chunks of long packet are missing! Skipping...");
337 return Err(PacketReadError::IncompleteLongPacket)
338 }
339 }
340 else {
341 (None, data_len)
342 }
343 }
344 else {
345 if !flags.contains(PacketFlags::PacketStart) {
346 warn!("Assertion failed! Udp packet chunk without start flag!");
347 }
348 *partial_packet_info = Some(UdpParserState::new(chunk_num, data));
349 (None, data_len)
350 }
351 };
352
353
354 if packet_type == PacketType::DataBytes {
355 counters.trace_buf += data_len;
356 }
357 else {
358 counters.secondary_packets += data_len;
359
360 if packet_type == PacketType::GracefulShutdown {
361 *is_eof = true;
362 }
363 }
364 Ok(res)
365 }
366 }
367 }
368
369 pub fn is_eof(&self) -> bool {
370 match self {
371 PacketDecoder::Socket{is_eof, ..} |
372 PacketDecoder::Stream{is_eof, ..} => *is_eof
373 }
374 }
375 pub fn counters(&self) -> ProtocolCounters {
376 match self {
377 PacketDecoder::Socket{counters, ..} |
378 PacketDecoder::Stream{counters, ..} => *counters
379 }
380 }
381}
382
383fn parse_packet_from_data(packet_type: PacketType, data: &[u8]) -> ReadResult<Packet> {
384 match packet_type {
385 PacketType::GracefulShutdown => {
386 Ok(Packet::GracefulShutdown)
387 }
388 PacketType::MachineInfo => {
389 info!("Got MachineInfo packet!");
390 let (machine_info, sz) = bincode::decode_from_slice(data, bincode_config())?;
391 if sz != data.len() {
392 warn!("[PacketDecoder] Assertion failed! MachineInfo packet size mismatch!");
393 }
394 Ok(Packet::MachineInfo(machine_info))
395 }
396 PacketType::DataBytes => {
397 let mut res = Vec::new();
398 let mut cursor = 0;
399 loop {
400 if cursor + 8 >= data.len() {
401 warn!("[PacketDecoder] DataBytes partial data received!");
402 break;
403 }
404 let length = u64::from_be_bytes(data[cursor..cursor + 8].try_into().unwrap()) as usize;
405 cursor += 8;
406 debug!("local packet header len: {length}");
407
408 if cursor + length >= data.len() {
409 warn!("[PacketDecoder] DataBytes partial data received!");
410 break;
411 }
412 let (local_packet_header, sz) = bincode::decode_from_slice(&data[cursor..], bincode_config())?;
413 if sz != length {
414 warn!("[PacketDecoder] Assertion failed! LocalPacketHeader packet size mismatch!");
415 }
416 cursor += length;
417
418 debug!("LocalPacketHeader received!");
419
420 if cursor + 8 >= data.len() {
421 warn!("[PacketDecoder] DataBytes partial data received!");
422 break;
423 }
424 let buf_len = u64::from_be_bytes(data[cursor..cursor + 8].try_into().unwrap()) as usize;
425 cursor += 8;
426
427 debug!("Data len received: {buf_len}");
428
429 if cursor + buf_len > data.len() {
430 warn!("[PacketDecoder] DataBytes partial data received!");
431 break;
432 }
433 let buf = data[cursor..cursor + buf_len].to_vec();
434 res.push((local_packet_header, buf));
435 cursor += buf_len;
436 debug!("Got {}KB of data!", buf_len as f32 / 1024.0);
437
438 if cursor == data.len() {
439 break;
440 }
441 }
442 Ok(Packet::DataBytes(res))
443 }
444 PacketType::FailedPages => {
445 let (failed_pages, sz) = bincode::decode_from_slice(data, bincode_config())?;
446 if sz != data.len() {
447 warn!("[PacketDecoder] Assertion failed! FailedPages packet size mismatch!");
448 }
449 Ok(Packet::FailedPages(failed_pages))
450 }
451 PacketType::TimestampFreq => {
452 let freq = u64::from_be_bytes(data[..8].try_into().unwrap());
453 let cur_tm = u64::from_be_bytes(data[8..16].try_into().unwrap());
454 Ok(Packet::TimestampFreq(freq, cur_tm))
455 }
456 PacketType::ConnectionAccepted => {
457 Ok(Packet::ConnectionAccepted)
458 }
459 }
460}
461
462fn bincode_config() -> impl bincode::config::Config {
463 bincode::config::standard().with_limit::<100_000>()
464}