1use crate::{
4 buffer::BufferUtils,
5 error::{DecodeError, GpbError},
6 FieldValue, FixMessage, GpbReader, MessageType,
7};
8use fastrace::prelude::*;
9use std::collections::HashMap;
10use std::io::Read;
11
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Serialize};
14
/// Wire-type codes carried in the low three bits of every field tag.
///
/// The discriminants are compared against the raw `u8` extracted from a tag
/// (`tag & 0x07`), so their numeric values are part of the wire format.
#[derive(Debug, Clone, Copy)]
#[repr(u8)]
enum WireType {
    /// Variable-length integer payload.
    Varint = 0,
    /// Eight-byte fixed-width payload.
    Fixed64 = 1,
    /// Varint length prefix followed by that many bytes.
    LengthDelimited = 2,
    /// Four-byte fixed-width payload.
    Fixed32 = 5,
}
24
/// Tunable switches that control how a decoder treats incoming buffers.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct DecodeConfig {
    /// When `true`, `decode` verifies the CRC32 carried in field 999 and
    /// rejects messages that do not contain a checksum field.
    pub validate_checksums: bool,
    /// When `true`, `decode` runs `FixMessage::validate` on the assembled
    /// message before returning it.
    pub validate_structure: bool,
    /// Largest input, in bytes, that `decode` accepts.
    pub max_message_size: usize,
    /// NOTE(review): not consulted by the decoder code in this file; confirm
    /// where (or whether) this flag is read.
    pub strict_field_validation: bool,
    /// NOTE(review): not consulted by the decoder code in this file; confirm
    /// where (or whether) this flag is read.
    pub zero_copy_strings: bool,
}
40
41impl Default for DecodeConfig {
42 fn default() -> Self {
43 Self {
44 validate_checksums: true,
45 validate_structure: true,
46 max_message_size: 1024 * 1024, strict_field_validation: true,
48 zero_copy_strings: false, }
50 }
51}
52
53#[derive(Debug)]
55pub struct GpbDecoder {
56 config: DecodeConfig,
57 reverse_field_mappings: HashMap<u32, u32>,
59}
60
61impl GpbDecoder {
62 pub fn new() -> Self {
64 Self::with_config(DecodeConfig::default())
65 }
66
67 pub fn with_config(config: DecodeConfig) -> Self {
69 Self {
70 config,
71 reverse_field_mappings: Self::create_reverse_field_mappings(),
72 }
73 }
74
75 #[trace]
77 pub fn decode(&self, data: &[u8]) -> Result<FixMessage, GpbError> {
78 if data.len() > self.config.max_message_size {
80 return Err(GpbError::Decode(DecodeError::TruncatedBuffer {
81 expected: self.config.max_message_size,
82 actual: data.len(),
83 }));
84 }
85
86 let mut reader = GpbReader::new(data)?;
87
88 let mut message_type = MessageType::Heartbeat;
90 let mut seq_num = None;
91 let mut sender_comp_id = None;
92 let mut target_comp_id = None;
93 let mut sending_time = None;
94 let mut fields = HashMap::new();
95 let mut checksum_validated = false;
96
97 while reader.has_remaining() {
99 let tag_and_wire = self.decode_varint(&mut reader)?;
100 let field_num = (tag_and_wire >> 3) as u32;
101 let wire_type = (tag_and_wire & 0x07) as u8;
102
103 match field_num {
104 1 => {
105 if wire_type != WireType::LengthDelimited as u8 {
107 return Err(GpbError::Decode(DecodeError::InvalidWireFormat {
108 reason: format!("Invalid wire type for message type: {}", wire_type),
109 }));
110 }
111 let msg_type_str = self.decode_string(&mut reader)?;
112 message_type = MessageType::from_str(&msg_type_str);
113 }
114 2 => {
115 if wire_type != WireType::Varint as u8 {
117 return Err(GpbError::Decode(DecodeError::InvalidWireFormat {
118 reason: format!("Invalid wire type for sequence number: {}", wire_type),
119 }));
120 }
121 seq_num = Some(self.decode_varint(&mut reader)? as u32);
122 }
123 3 => {
124 if wire_type != WireType::LengthDelimited as u8 {
126 return Err(GpbError::Decode(DecodeError::InvalidWireFormat {
127 reason: format!("Invalid wire type for sender CompID: {}", wire_type),
128 }));
129 }
130 sender_comp_id = Some(self.decode_string(&mut reader)?);
131 }
132 4 => {
133 if wire_type != WireType::LengthDelimited as u8 {
135 return Err(GpbError::Decode(DecodeError::InvalidWireFormat {
136 reason: format!("Invalid wire type for target CompID: {}", wire_type),
137 }));
138 }
139 target_comp_id = Some(self.decode_string(&mut reader)?);
140 }
141 5 => {
142 if wire_type != WireType::Varint as u8 {
144 return Err(GpbError::Decode(DecodeError::InvalidWireFormat {
145 reason: format!("Invalid wire type for sending time: {}", wire_type),
146 }));
147 }
148 sending_time = Some(self.decode_varint(&mut reader)?);
149 }
150 999 => {
151 if wire_type != WireType::Fixed32 as u8 {
153 return Err(GpbError::Decode(DecodeError::InvalidWireFormat {
154 reason: format!("Invalid wire type for checksum: {}", wire_type),
155 }));
156 }
157 let checksum = self.decode_fixed32(&mut reader)?;
158
159 if self.config.validate_checksums {
160 let data_without_checksum = &data[..data.len() - 6]; let calculated_checksum = BufferUtils::crc32(data_without_checksum);
163
164 if checksum != calculated_checksum {
165 return Err(GpbError::Decode(DecodeError::ChecksumMismatch {
166 expected: checksum,
167 actual: calculated_checksum,
168 }));
169 }
170 checksum_validated = true;
171 }
172 }
173 0 => {
174 self.skip_field(&mut reader, wire_type)?;
176 }
177 _ => {
178 let fix_tag = self.map_gpb_field_to_fix_tag(field_num);
180 let field_value = self.decode_field_value(&mut reader, wire_type)?;
181 fields.insert(fix_tag, field_value);
182 }
183 }
184 }
185
186 if self.config.validate_checksums && !checksum_validated {
188 return Err(GpbError::Decode(DecodeError::InvalidWireFormat {
189 reason: "Missing required checksum".to_string(),
190 }));
191 }
192
193 let mut message = FixMessage::new(message_type);
195 message.seq_num = seq_num;
196 message.sender_comp_id = sender_comp_id;
197 message.target_comp_id = target_comp_id;
198 message.sending_time = sending_time;
199 message.fields = fields;
200
201 if self.config.validate_structure {
203 message.validate().map_err(GpbError::Encode)?;
204 }
205
206 Ok(message)
207 }
208
209 #[trace]
211 pub fn decode_batch(&self, data: &[u8]) -> Result<Vec<FixMessage>, GpbError> {
212 let mut reader = GpbReader::new(data)?;
213 let mut messages = Vec::new();
214
215 let tag_and_wire = self.decode_varint(&mut reader)?;
217 let field_num = (tag_and_wire >> 3) as u32;
218 let wire_type = (tag_and_wire & 0x07) as u8;
219
220 if field_num != 0 || wire_type != WireType::Varint as u8 {
221 return Err(GpbError::Decode(DecodeError::InvalidWireFormat {
222 reason: "Invalid batch header".to_string(),
223 }));
224 }
225
226 let message_count = self.decode_varint(&mut reader)? as usize;
227 messages.reserve(message_count);
228
229 for _ in 0..message_count {
231 if !reader.has_remaining() {
232 break;
233 }
234
235 let msg_length = self.decode_varint(&mut reader)? as usize;
237
238 let mut msg_data = vec![0u8; msg_length];
240 reader.read_exact(&mut msg_data).map_err(|_| {
241 GpbError::Decode(DecodeError::TruncatedBuffer {
242 expected: msg_length,
243 actual: reader.remaining(),
244 })
245 })?;
246
247 let mut batch_config = self.config.clone();
250 batch_config.validate_checksums = false;
251 let batch_decoder = GpbDecoder::with_config(batch_config);
252 let message = batch_decoder.decode(&msg_data)?;
253 messages.push(message);
254 }
255
256 Ok(messages)
257 }
258
259 #[trace]
261 fn decode_field_value(
262 &self,
263 reader: &mut GpbReader,
264 wire_type: u8,
265 ) -> Result<FieldValue, DecodeError> {
266 match wire_type {
267 t if t == WireType::Varint as u8 => {
268 let value = self.decode_varint(reader)?;
269 let zigzag_decoded = self.decode_zigzag(value);
272
273 if zigzag_decoded < 0 || (value & 1) == 1 {
276 Ok(FieldValue::Int(zigzag_decoded))
277 } else {
278 Ok(FieldValue::UInt(value))
279 }
280 }
281 t if t == WireType::Fixed64 as u8 => {
282 let value = self.decode_double(reader)?;
283 Ok(FieldValue::Float(value))
284 }
285 t if t == WireType::LengthDelimited as u8 => {
286 let bytes = self.decode_bytes(reader)?;
287
288 match String::from_utf8(bytes.clone()) {
290 Ok(string) => Ok(FieldValue::String(string)),
291 Err(_) => {
292 if let Ok(decimal) = self.try_decode_decimal(&bytes) {
294 Ok(decimal)
295 } else {
296 Ok(FieldValue::Bytes(bytes))
297 }
298 }
299 }
300 }
301 t if t == WireType::Fixed32 as u8 => {
302 let value = self.decode_fixed32(reader)?;
303 Ok(FieldValue::UInt(value as u64))
304 }
305 _ => Err(DecodeError::InvalidWireFormat {
306 reason: format!("Unknown wire type: {}", wire_type),
307 }),
308 }
309 }
310
311 fn decode_varint(&self, reader: &mut GpbReader) -> Result<u64, DecodeError> {
313 BufferUtils::decode_varint(reader)
314 }
315
316 fn decode_zigzag(&self, value: u64) -> i64 {
318 ((value >> 1) as i64) ^ (-((value & 1) as i64))
319 }
320
321 fn decode_string(&self, reader: &mut GpbReader) -> Result<String, DecodeError> {
323 let bytes = self.decode_bytes(reader)?;
324 String::from_utf8(bytes).map_err(|e| DecodeError::InvalidWireFormat {
325 reason: format!("Invalid UTF-8: {}", e),
326 })
327 }
328
329 fn decode_bytes(&self, reader: &mut GpbReader) -> Result<Vec<u8>, DecodeError> {
331 let length = self.decode_varint(reader)? as usize;
332
333 if length > reader.remaining() {
334 return Err(DecodeError::TruncatedBuffer {
335 expected: length,
336 actual: reader.remaining(),
337 });
338 }
339
340 let mut bytes = vec![0u8; length];
341 reader
342 .read_exact(&mut bytes)
343 .map_err(|_| DecodeError::TruncatedBuffer {
344 expected: length,
345 actual: reader.remaining(),
346 })?;
347
348 Ok(bytes)
349 }
350
351 fn decode_double(&self, reader: &mut GpbReader) -> Result<f64, DecodeError> {
353 let mut bytes = [0u8; 8];
354 reader
355 .read_exact(&mut bytes)
356 .map_err(|_| DecodeError::TruncatedBuffer {
357 expected: 8,
358 actual: reader.remaining(),
359 })?;
360
361 Ok(f64::from_le_bytes(bytes))
362 }
363
364 fn decode_fixed32(&self, reader: &mut GpbReader) -> Result<u32, DecodeError> {
366 let mut bytes = [0u8; 4];
367 reader
368 .read_exact(&mut bytes)
369 .map_err(|_| DecodeError::TruncatedBuffer {
370 expected: 4,
371 actual: reader.remaining(),
372 })?;
373
374 Ok(u32::from_le_bytes(bytes))
375 }
376
377 fn try_decode_decimal(&self, bytes: &[u8]) -> Result<FieldValue, DecodeError> {
379 let mut reader = GpbReader::new(bytes).map_err(|e| match e {
380 GpbError::Io(_io_err) => DecodeError::TruncatedBuffer {
381 expected: bytes.len(),
382 actual: 0,
383 },
384 _ => DecodeError::InvalidWireFormat {
385 reason: format!("Failed to create reader: {}", e),
386 },
387 })?;
388 let mut mantissa = None;
389 let mut scale = None;
390
391 while reader.has_remaining() {
392 let tag_and_wire = self.decode_varint(&mut reader)?;
393 let field_num = (tag_and_wire >> 3) as u32;
394 let wire_type = (tag_and_wire & 0x07) as u8;
395
396 if wire_type != WireType::Varint as u8 {
397 return Err(DecodeError::InvalidWireFormat {
398 reason: "Invalid decimal field wire type".to_string(),
399 });
400 }
401
402 match field_num {
403 1 => {
404 let value = self.decode_varint(&mut reader)?;
406 mantissa = Some(self.decode_zigzag(value));
407 }
408 2 => {
409 let value = self.decode_varint(&mut reader)?;
411 scale = Some(self.decode_zigzag(value) as i32);
412 }
413 _ => {
414 self.skip_field(&mut reader, wire_type)?;
416 }
417 }
418 }
419
420 match (mantissa, scale) {
421 (Some(m), Some(s)) => Ok(FieldValue::Decimal {
422 mantissa: m,
423 scale: s,
424 }),
425 _ => Err(DecodeError::InvalidWireFormat {
426 reason: "Incomplete decimal fields".to_string(),
427 }),
428 }
429 }
430
431 fn skip_field(&self, reader: &mut GpbReader, wire_type: u8) -> Result<(), DecodeError> {
433 match wire_type {
434 t if t == WireType::Varint as u8 => {
435 self.decode_varint(reader)?;
436 }
437 t if t == WireType::Fixed64 as u8 => {
438 let mut bytes = [0u8; 8];
439 reader
440 .read_exact(&mut bytes)
441 .map_err(|_| DecodeError::TruncatedBuffer {
442 expected: 8,
443 actual: reader.remaining(),
444 })?;
445 }
446 t if t == WireType::LengthDelimited as u8 => {
447 let length = self.decode_varint(reader)? as usize;
448 let mut bytes = vec![0u8; length];
449 reader
450 .read_exact(&mut bytes)
451 .map_err(|_| DecodeError::TruncatedBuffer {
452 expected: length,
453 actual: reader.remaining(),
454 })?;
455 }
456 t if t == WireType::Fixed32 as u8 => {
457 let mut bytes = [0u8; 4];
458 reader
459 .read_exact(&mut bytes)
460 .map_err(|_| DecodeError::TruncatedBuffer {
461 expected: 4,
462 actual: reader.remaining(),
463 })?;
464 }
465 _ => {
466 return Err(DecodeError::InvalidWireFormat {
467 reason: format!("Unknown wire type to skip: {}", wire_type),
468 });
469 }
470 }
471
472 Ok(())
473 }
474
475 fn map_gpb_field_to_fix_tag(&self, gpb_field: u32) -> u32 {
477 self.reverse_field_mappings
478 .get(&gpb_field)
479 .copied()
480 .unwrap_or(gpb_field.saturating_sub(100)) }
482
/// Builds the lookup table from GPB field number to FIX tag.
///
/// NOTE(review): the tag names in the comments below follow the standard FIX
/// tag dictionary; confirm them against the encoder's forward table.
fn create_reverse_field_mappings() -> HashMap<u32, u32> {
    // (GPB field number, FIX tag)
    const GPB_TO_FIX: [(u32, u32); 19] = [
        (10, 8),   // BeginString
        (11, 9),   // BodyLength
        (12, 35),  // MsgType
        (13, 34),  // MsgSeqNum
        (14, 49),  // SenderCompID
        (15, 56),  // TargetCompID
        (16, 52),  // SendingTime
        (20, 55),  // Symbol
        (21, 44),  // Price
        (22, 38),  // OrderQty
        (23, 54),  // Side
        (24, 40),  // OrdType
        (25, 59),  // TimeInForce
        (30, 37),  // OrderID
        (31, 17),  // ExecID
        (32, 150), // ExecType
        (33, 39),  // OrdStatus
        (34, 32),  // LastQty
        (35, 31),  // LastPx
    ];

    GPB_TO_FIX.iter().copied().collect()
}
510
511 pub fn stats(&self) -> DecoderStats {
513 DecoderStats {
514 reverse_mappings_count: self.reverse_field_mappings.len(),
515 config: self.config.clone(),
516 }
517 }
518}
519
520impl Default for GpbDecoder {
521 fn default() -> Self {
522 Self::new()
523 }
524}
525
526#[derive(Debug, Clone)]
528pub struct DecoderStats {
529 pub reverse_mappings_count: usize,
531 pub config: DecodeConfig,
533}
534
535#[cfg(test)]
536mod tests {
537 use super::*;
538 use crate::{GpbEncoder, MessageType};
539
/// A freshly created decoder exposes a non-empty reverse mapping table.
#[test]
fn test_decoder_creation() {
    let mapping_count = GpbDecoder::new().stats().reverse_mappings_count;
    assert!(mapping_count > 0);
}
546
/// Encoding and then decoding an order keeps its type, field count, and the
/// values of tags 55 and 44.
#[test]
fn test_encode_decode_round_trip() {
    let decoder = GpbDecoder::new();
    let mut encoder = GpbEncoder::new();

    let sent = FixMessage::new_order_single("BTCUSD".to_string(), 50000.0, 1.5, "1".to_string());

    let wire = encoder.encode(&sent).unwrap();
    let received = decoder.decode(wire).unwrap();

    assert_eq!(sent.message_type, received.message_type);
    assert_eq!(sent.fields.len(), received.fields.len());

    let sent_symbol = sent.get_field(55).unwrap().as_string();
    let received_symbol = received.get_field(55).unwrap().as_string();
    assert_eq!(sent_symbol, received_symbol);

    let sent_price = sent.get_field(44).unwrap().as_float();
    let received_price = received.get_field(44).unwrap().as_float();
    assert_eq!(sent_price, received_price);
}
575
/// A message produced by the encoder decodes successfully when checksum
/// validation is explicitly enabled.
#[test]
fn test_decode_with_checksum_validation() {
    let checked_config = DecodeConfig {
        validate_checksums: true,
        ..Default::default()
    };
    let decoder = GpbDecoder::with_config(checked_config);
    let mut encoder = GpbEncoder::new();

    let order = FixMessage::new_order_single("ETHUSD".to_string(), 3000.0, 2.0, "2".to_string());

    let wire = encoder.encode(&order).unwrap();
    let round_tripped = decoder.decode(wire).unwrap();

    assert_eq!(order.message_type, round_tripped.message_type);
}
592
/// Each supported value kind survives an encode/decode round trip as the
/// same `FieldValue` variant (the boolean in tag 5 is only counted, not
/// checked for its variant).
#[test]
fn test_decode_different_field_types() {
    let mut encoder = GpbEncoder::new();
    let decoder = GpbDecoder::new();

    let samples = vec![
        (1, FieldValue::String("test".to_string())),
        (2, FieldValue::Int(-123)),
        (3, FieldValue::UInt(456)),
        (4, FieldValue::Float(123.45)),
        (5, FieldValue::Bool(true)),
        (
            6,
            FieldValue::Decimal {
                mantissa: 12345,
                scale: 2,
            },
        ),
    ];

    let mut message = FixMessage::new(MessageType::Heartbeat);
    for (tag, value) in samples {
        message.set_field(tag, value);
    }

    let wire = encoder.encode(&message).unwrap();
    let decoded = decoder.decode(wire).unwrap();

    assert_eq!(message.fields.len(), decoded.fields.len());

    let text_field = decoded.get_field(1).unwrap();
    assert!(matches!(text_field, FieldValue::String(_)));

    let signed_field = decoded.get_field(2).unwrap();
    assert!(matches!(signed_field, FieldValue::Int(_)));

    let unsigned_field = decoded.get_field(3).unwrap();
    assert!(matches!(unsigned_field, FieldValue::UInt(_)));

    let float_field = decoded.get_field(4).unwrap();
    assert!(matches!(float_field, FieldValue::Float(_)));

    let decimal_field = decoded.get_field(6).unwrap();
    assert!(matches!(decimal_field, FieldValue::Decimal { .. }));
}
633
/// Malformed inputs are rejected with an error rather than decoded.
#[test]
fn test_invalid_data_handling() {
    let decoder = GpbDecoder::new();

    let invalid_data = b"this is not protobuf data";
    let result = decoder.decode(invalid_data);
    assert!(result.is_err());

    // Tag 0x08 is field 1 with the varint wire type, but `decode` requires
    // field 1 (message type) to be length-delimited, so this must fail too.
    let wrong_wire_type_data = b"\x08\x96\x01";
    let result = decoder.decode(wrong_wire_type_data);
    assert!(result.is_err());
}
649
/// A batch of two orders decodes into two messages whose types match the
/// originals, in order.
#[test]
fn test_batch_decode() {
    let decoder = GpbDecoder::new();
    let mut encoder = GpbEncoder::new();

    let btc_order = FixMessage::new_order_single("BTC".to_string(), 50000.0, 1.0, "1".to_string());
    let eth_order = FixMessage::new_order_single("ETH".to_string(), 3000.0, 2.0, "2".to_string());
    let messages = vec![btc_order, eth_order];

    let wire = encoder.encode_batch(&messages).unwrap();
    let decoded_messages = decoder.decode_batch(wire).unwrap();

    assert_eq!(messages.len(), decoded_messages.len());

    let pairs = messages.iter().zip(decoded_messages.iter());
    for (sent, received) in pairs {
        assert_eq!(sent.message_type, received.message_type);
    }
}
669}