rustygpb/
decoder.rs

1//! High-performance Protocol Buffers decoder for FIX messages.
2
3use crate::{
4    buffer::BufferUtils,
5    error::{DecodeError, GpbError},
6    FieldValue, FixMessage, GpbReader, MessageType,
7};
8use fastrace::prelude::*;
9use std::collections::HashMap;
10use std::io::Read;
11
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Serialize};
14
/// Protocol Buffers wire types understood by this decoder.
///
/// Each discriminant is the raw value carried in the low three bits of a
/// field key, so a variant can be compared with a raw wire type via `as u8`.
#[derive(Debug, Clone, Copy)]
#[repr(u8)]
enum WireType {
    /// Variable-length integer.
    Varint = 0,
    /// Fixed-width 8-byte value.
    Fixed64 = 1,
    /// Length-prefixed byte sequence.
    LengthDelimited = 2,
    /// Fixed-width 4-byte value.
    Fixed32 = 5,
}
24
/// Tunable options controlling how the GPB decoder processes its input.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct DecodeConfig {
    /// When `true`, the checksum field is verified and a message without
    /// one is rejected.
    pub validate_checksums: bool,
    /// When `true`, the decoded message is validated before being returned.
    pub validate_structure: bool,
    /// Largest input, in bytes, that the decoder will accept.
    pub max_message_size: usize,
    /// Request strict per-field validation.
    // NOTE(review): not read by the decoder code in this file — confirm where it is used.
    pub strict_field_validation: bool,
    /// Request the zero-copy string optimization.
    // NOTE(review): not read by the decoder code in this file — confirm where it is used.
    pub zero_copy_strings: bool,
}

impl Default for DecodeConfig {
    /// Every validation switched on, a 1 MiB size limit, zero-copy strings off.
    fn default() -> Self {
        // 1 MiB ceiling on accepted input.
        const DEFAULT_MAX_MESSAGE_SIZE: usize = 1 << 20;

        DecodeConfig {
            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
            // Kept off by default for safety.
            zero_copy_strings: false,
            validate_checksums: true,
            validate_structure: true,
            strict_field_validation: true,
        }
    }
}
52
53/// High-performance Protocol Buffers decoder for FIX messages.
54#[derive(Debug)]
55pub struct GpbDecoder {
56    config: DecodeConfig,
57    /// Reverse field mappings (GPB field -> FIX tag)
58    reverse_field_mappings: HashMap<u32, u32>,
59}
60
61impl GpbDecoder {
62    /// Create a new decoder with default configuration.
63    pub fn new() -> Self {
64        Self::with_config(DecodeConfig::default())
65    }
66
67    /// Create decoder with custom configuration.
68    pub fn with_config(config: DecodeConfig) -> Self {
69        Self {
70            config,
71            reverse_field_mappings: Self::create_reverse_field_mappings(),
72        }
73    }
74
75    /// Decode a Protocol Buffers message to FIX format.
76    #[trace]
77    pub fn decode(&self, data: &[u8]) -> Result<FixMessage, GpbError> {
78        // Validate input size
79        if data.len() > self.config.max_message_size {
80            return Err(GpbError::Decode(DecodeError::TruncatedBuffer {
81                expected: self.config.max_message_size,
82                actual: data.len(),
83            }));
84        }
85
86        let mut reader = GpbReader::new(data)?;
87
88        // Start with default message
89        let mut message_type = MessageType::Heartbeat;
90        let mut seq_num = None;
91        let mut sender_comp_id = None;
92        let mut target_comp_id = None;
93        let mut sending_time = None;
94        let mut fields = HashMap::new();
95        let mut checksum_validated = false;
96
97        // Decode all fields
98        while reader.has_remaining() {
99            let tag_and_wire = self.decode_varint(&mut reader)?;
100            let field_num = (tag_and_wire >> 3) as u32;
101            let wire_type = (tag_and_wire & 0x07) as u8;
102
103            match field_num {
104                1 => {
105                    // Message Type
106                    if wire_type != WireType::LengthDelimited as u8 {
107                        return Err(GpbError::Decode(DecodeError::InvalidWireFormat {
108                            reason: format!("Invalid wire type for message type: {}", wire_type),
109                        }));
110                    }
111                    let msg_type_str = self.decode_string(&mut reader)?;
112                    message_type = MessageType::from_str(&msg_type_str);
113                }
114                2 => {
115                    // Sequence Number
116                    if wire_type != WireType::Varint as u8 {
117                        return Err(GpbError::Decode(DecodeError::InvalidWireFormat {
118                            reason: format!("Invalid wire type for sequence number: {}", wire_type),
119                        }));
120                    }
121                    seq_num = Some(self.decode_varint(&mut reader)? as u32);
122                }
123                3 => {
124                    // Sender CompID
125                    if wire_type != WireType::LengthDelimited as u8 {
126                        return Err(GpbError::Decode(DecodeError::InvalidWireFormat {
127                            reason: format!("Invalid wire type for sender CompID: {}", wire_type),
128                        }));
129                    }
130                    sender_comp_id = Some(self.decode_string(&mut reader)?);
131                }
132                4 => {
133                    // Target CompID
134                    if wire_type != WireType::LengthDelimited as u8 {
135                        return Err(GpbError::Decode(DecodeError::InvalidWireFormat {
136                            reason: format!("Invalid wire type for target CompID: {}", wire_type),
137                        }));
138                    }
139                    target_comp_id = Some(self.decode_string(&mut reader)?);
140                }
141                5 => {
142                    // Sending Time
143                    if wire_type != WireType::Varint as u8 {
144                        return Err(GpbError::Decode(DecodeError::InvalidWireFormat {
145                            reason: format!("Invalid wire type for sending time: {}", wire_type),
146                        }));
147                    }
148                    sending_time = Some(self.decode_varint(&mut reader)?);
149                }
150                999 => {
151                    // Checksum field
152                    if wire_type != WireType::Fixed32 as u8 {
153                        return Err(GpbError::Decode(DecodeError::InvalidWireFormat {
154                            reason: format!("Invalid wire type for checksum: {}", wire_type),
155                        }));
156                    }
157                    let checksum = self.decode_fixed32(&mut reader)?;
158
159                    if self.config.validate_checksums {
160                        // Calculate checksum of data excluding the checksum field
161                        let data_without_checksum = &data[..data.len() - 6]; // Assume 6 bytes for checksum field
162                        let calculated_checksum = BufferUtils::crc32(data_without_checksum);
163
164                        if checksum != calculated_checksum {
165                            return Err(GpbError::Decode(DecodeError::ChecksumMismatch {
166                                expected: checksum,
167                                actual: calculated_checksum,
168                            }));
169                        }
170                        checksum_validated = true;
171                    }
172                }
173                0 => {
174                    // Batch header - skip for now
175                    self.skip_field(&mut reader, wire_type)?;
176                }
177                _ => {
178                    // Regular FIX field
179                    let fix_tag = self.map_gpb_field_to_fix_tag(field_num);
180                    let field_value = self.decode_field_value(&mut reader, wire_type)?;
181                    fields.insert(fix_tag, field_value);
182                }
183            }
184        }
185
186        // Validate checksum was present if required
187        if self.config.validate_checksums && !checksum_validated {
188            return Err(GpbError::Decode(DecodeError::InvalidWireFormat {
189                reason: "Missing required checksum".to_string(),
190            }));
191        }
192
193        // Construct the message
194        let mut message = FixMessage::new(message_type);
195        message.seq_num = seq_num;
196        message.sender_comp_id = sender_comp_id;
197        message.target_comp_id = target_comp_id;
198        message.sending_time = sending_time;
199        message.fields = fields;
200
201        // Validate message structure if configured
202        if self.config.validate_structure {
203            message.validate().map_err(GpbError::Encode)?;
204        }
205
206        Ok(message)
207    }
208
209    /// Decode a batch of messages.
210    #[trace]
211    pub fn decode_batch(&self, data: &[u8]) -> Result<Vec<FixMessage>, GpbError> {
212        let mut reader = GpbReader::new(data)?;
213        let mut messages = Vec::new();
214
215        // First field should be batch count
216        let tag_and_wire = self.decode_varint(&mut reader)?;
217        let field_num = (tag_and_wire >> 3) as u32;
218        let wire_type = (tag_and_wire & 0x07) as u8;
219
220        if field_num != 0 || wire_type != WireType::Varint as u8 {
221            return Err(GpbError::Decode(DecodeError::InvalidWireFormat {
222                reason: "Invalid batch header".to_string(),
223            }));
224        }
225
226        let message_count = self.decode_varint(&mut reader)? as usize;
227        messages.reserve(message_count);
228
229        // Decode each message
230        for _ in 0..message_count {
231            if !reader.has_remaining() {
232                break;
233            }
234
235            // Read message length
236            let msg_length = self.decode_varint(&mut reader)? as usize;
237
238            // Extract message data
239            let mut msg_data = vec![0u8; msg_length];
240            reader.read_exact(&mut msg_data).map_err(|_| {
241                GpbError::Decode(DecodeError::TruncatedBuffer {
242                    expected: msg_length,
243                    actual: reader.remaining(),
244                })
245            })?;
246
247            // Decode the individual message with relaxed checksum validation
248            // (individual messages in batch don't have checksums)
249            let mut batch_config = self.config.clone();
250            batch_config.validate_checksums = false;
251            let batch_decoder = GpbDecoder::with_config(batch_config);
252            let message = batch_decoder.decode(&msg_data)?;
253            messages.push(message);
254        }
255
256        Ok(messages)
257    }
258
259    /// Decode field value based on wire type.
260    #[trace]
261    fn decode_field_value(
262        &self,
263        reader: &mut GpbReader,
264        wire_type: u8,
265    ) -> Result<FieldValue, DecodeError> {
266        match wire_type {
267            t if t == WireType::Varint as u8 => {
268                let value = self.decode_varint(reader)?;
269                // Try to determine if this is signed (zigzag encoded)
270                // Check if the zigzag decode produces a negative number that makes sense
271                let zigzag_decoded = self.decode_zigzag(value);
272
273                // If the zigzag decoding produces a negative number or the value is odd
274                // (which indicates a negative number in zigzag encoding), treat as signed
275                if zigzag_decoded < 0 || (value & 1) == 1 {
276                    Ok(FieldValue::Int(zigzag_decoded))
277                } else {
278                    Ok(FieldValue::UInt(value))
279                }
280            }
281            t if t == WireType::Fixed64 as u8 => {
282                let value = self.decode_double(reader)?;
283                Ok(FieldValue::Float(value))
284            }
285            t if t == WireType::LengthDelimited as u8 => {
286                let bytes = self.decode_bytes(reader)?;
287
288                // Try to decode as UTF-8 string first
289                match String::from_utf8(bytes.clone()) {
290                    Ok(string) => Ok(FieldValue::String(string)),
291                    Err(_) => {
292                        // Check if it might be a decimal (embedded message)
293                        if let Ok(decimal) = self.try_decode_decimal(&bytes) {
294                            Ok(decimal)
295                        } else {
296                            Ok(FieldValue::Bytes(bytes))
297                        }
298                    }
299                }
300            }
301            t if t == WireType::Fixed32 as u8 => {
302                let value = self.decode_fixed32(reader)?;
303                Ok(FieldValue::UInt(value as u64))
304            }
305            _ => Err(DecodeError::InvalidWireFormat {
306                reason: format!("Unknown wire type: {}", wire_type),
307            }),
308        }
309    }
310
311    /// Decode variable-length integer.
312    fn decode_varint(&self, reader: &mut GpbReader) -> Result<u64, DecodeError> {
313        BufferUtils::decode_varint(reader)
314    }
315
316    /// Decode zigzag-encoded signed integer.
317    fn decode_zigzag(&self, value: u64) -> i64 {
318        ((value >> 1) as i64) ^ (-((value & 1) as i64))
319    }
320
321    /// Decode string with length prefix.
322    fn decode_string(&self, reader: &mut GpbReader) -> Result<String, DecodeError> {
323        let bytes = self.decode_bytes(reader)?;
324        String::from_utf8(bytes).map_err(|e| DecodeError::InvalidWireFormat {
325            reason: format!("Invalid UTF-8: {}", e),
326        })
327    }
328
329    /// Decode byte array with length prefix.
330    fn decode_bytes(&self, reader: &mut GpbReader) -> Result<Vec<u8>, DecodeError> {
331        let length = self.decode_varint(reader)? as usize;
332
333        if length > reader.remaining() {
334            return Err(DecodeError::TruncatedBuffer {
335                expected: length,
336                actual: reader.remaining(),
337            });
338        }
339
340        let mut bytes = vec![0u8; length];
341        reader
342            .read_exact(&mut bytes)
343            .map_err(|_| DecodeError::TruncatedBuffer {
344                expected: length,
345                actual: reader.remaining(),
346            })?;
347
348        Ok(bytes)
349    }
350
351    /// Decode 64-bit double.
352    fn decode_double(&self, reader: &mut GpbReader) -> Result<f64, DecodeError> {
353        let mut bytes = [0u8; 8];
354        reader
355            .read_exact(&mut bytes)
356            .map_err(|_| DecodeError::TruncatedBuffer {
357                expected: 8,
358                actual: reader.remaining(),
359            })?;
360
361        Ok(f64::from_le_bytes(bytes))
362    }
363
364    /// Decode 32-bit fixed integer.
365    fn decode_fixed32(&self, reader: &mut GpbReader) -> Result<u32, DecodeError> {
366        let mut bytes = [0u8; 4];
367        reader
368            .read_exact(&mut bytes)
369            .map_err(|_| DecodeError::TruncatedBuffer {
370                expected: 4,
371                actual: reader.remaining(),
372            })?;
373
374        Ok(u32::from_le_bytes(bytes))
375    }
376
377    /// Try to decode bytes as a decimal embedded message.
378    fn try_decode_decimal(&self, bytes: &[u8]) -> Result<FieldValue, DecodeError> {
379        let mut reader = GpbReader::new(bytes).map_err(|e| match e {
380            GpbError::Io(_io_err) => DecodeError::TruncatedBuffer {
381                expected: bytes.len(),
382                actual: 0,
383            },
384            _ => DecodeError::InvalidWireFormat {
385                reason: format!("Failed to create reader: {}", e),
386            },
387        })?;
388        let mut mantissa = None;
389        let mut scale = None;
390
391        while reader.has_remaining() {
392            let tag_and_wire = self.decode_varint(&mut reader)?;
393            let field_num = (tag_and_wire >> 3) as u32;
394            let wire_type = (tag_and_wire & 0x07) as u8;
395
396            if wire_type != WireType::Varint as u8 {
397                return Err(DecodeError::InvalidWireFormat {
398                    reason: "Invalid decimal field wire type".to_string(),
399                });
400            }
401
402            match field_num {
403                1 => {
404                    // Mantissa
405                    let value = self.decode_varint(&mut reader)?;
406                    mantissa = Some(self.decode_zigzag(value));
407                }
408                2 => {
409                    // Scale
410                    let value = self.decode_varint(&mut reader)?;
411                    scale = Some(self.decode_zigzag(value) as i32);
412                }
413                _ => {
414                    // Unknown field in decimal - skip
415                    self.skip_field(&mut reader, wire_type)?;
416                }
417            }
418        }
419
420        match (mantissa, scale) {
421            (Some(m), Some(s)) => Ok(FieldValue::Decimal {
422                mantissa: m,
423                scale: s,
424            }),
425            _ => Err(DecodeError::InvalidWireFormat {
426                reason: "Incomplete decimal fields".to_string(),
427            }),
428        }
429    }
430
431    /// Skip a field based on wire type.
432    fn skip_field(&self, reader: &mut GpbReader, wire_type: u8) -> Result<(), DecodeError> {
433        match wire_type {
434            t if t == WireType::Varint as u8 => {
435                self.decode_varint(reader)?;
436            }
437            t if t == WireType::Fixed64 as u8 => {
438                let mut bytes = [0u8; 8];
439                reader
440                    .read_exact(&mut bytes)
441                    .map_err(|_| DecodeError::TruncatedBuffer {
442                        expected: 8,
443                        actual: reader.remaining(),
444                    })?;
445            }
446            t if t == WireType::LengthDelimited as u8 => {
447                let length = self.decode_varint(reader)? as usize;
448                let mut bytes = vec![0u8; length];
449                reader
450                    .read_exact(&mut bytes)
451                    .map_err(|_| DecodeError::TruncatedBuffer {
452                        expected: length,
453                        actual: reader.remaining(),
454                    })?;
455            }
456            t if t == WireType::Fixed32 as u8 => {
457                let mut bytes = [0u8; 4];
458                reader
459                    .read_exact(&mut bytes)
460                    .map_err(|_| DecodeError::TruncatedBuffer {
461                        expected: 4,
462                        actual: reader.remaining(),
463                    })?;
464            }
465            _ => {
466                return Err(DecodeError::InvalidWireFormat {
467                    reason: format!("Unknown wire type to skip: {}", wire_type),
468                });
469            }
470        }
471
472        Ok(())
473    }
474
475    /// Map GPB field number back to FIX tag.
476    fn map_gpb_field_to_fix_tag(&self, gpb_field: u32) -> u32 {
477        self.reverse_field_mappings
478            .get(&gpb_field)
479            .copied()
480            .unwrap_or(gpb_field.saturating_sub(100)) // Reverse the offset
481    }
482
483    /// Create reverse field mappings (GPB field -> FIX tag).
484    fn create_reverse_field_mappings() -> HashMap<u32, u32> {
485        let mut mappings = HashMap::new();
486
487        // Reverse the encoder mappings
488        mappings.insert(10, 8); // 10 -> BeginString
489        mappings.insert(11, 9); // 11 -> BodyLength
490        mappings.insert(12, 35); // 12 -> MsgType
491        mappings.insert(13, 34); // 13 -> MsgSeqNum
492        mappings.insert(14, 49); // 14 -> SenderCompID
493        mappings.insert(15, 56); // 15 -> TargetCompID
494        mappings.insert(16, 52); // 16 -> SendingTime
495        mappings.insert(20, 55); // 20 -> Symbol
496        mappings.insert(21, 44); // 21 -> Price
497        mappings.insert(22, 38); // 22 -> OrderQty
498        mappings.insert(23, 54); // 23 -> Side
499        mappings.insert(24, 40); // 24 -> OrdType
500        mappings.insert(25, 59); // 25 -> TimeInForce
501        mappings.insert(30, 37); // 30 -> OrderID
502        mappings.insert(31, 17); // 31 -> ExecID
503        mappings.insert(32, 150); // 32 -> ExecType
504        mappings.insert(33, 39); // 33 -> OrdStatus
505        mappings.insert(34, 32); // 34 -> LastQty
506        mappings.insert(35, 31); // 35 -> LastPx
507
508        mappings
509    }
510
511    /// Get decoder statistics for monitoring.
512    pub fn stats(&self) -> DecoderStats {
513        DecoderStats {
514            reverse_mappings_count: self.reverse_field_mappings.len(),
515            config: self.config.clone(),
516        }
517    }
518}
519
520impl Default for GpbDecoder {
521    fn default() -> Self {
522        Self::new()
523    }
524}
525
526/// Decoder performance and usage statistics.
527#[derive(Debug, Clone)]
528pub struct DecoderStats {
529    /// Number of reverse field mappings
530    pub reverse_mappings_count: usize,
531    /// Current decoder configuration
532    pub config: DecodeConfig,
533}
534
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{GpbEncoder, MessageType};

    /// A freshly built decoder has a non-empty reverse mapping table.
    #[test]
    fn test_decoder_creation() {
        let decoder = GpbDecoder::new();
        let stats = decoder.stats();
        assert!(stats.reverse_mappings_count > 0);
    }

    /// Encoding then decoding an order preserves its type and fields.
    #[test]
    fn test_encode_decode_round_trip() {
        let mut encoder = GpbEncoder::new();
        let decoder = GpbDecoder::new();

        let original_message =
            FixMessage::new_order_single("BTCUSD".to_string(), 50000.0, 1.5, "1".to_string());

        // Encode
        let encoded = encoder.encode(&original_message).unwrap();

        // Decode
        let decoded_message = decoder.decode(encoded).unwrap();

        // Verify round trip
        assert_eq!(original_message.message_type, decoded_message.message_type);
        assert_eq!(original_message.fields.len(), decoded_message.fields.len());

        // Check specific fields
        assert_eq!(
            original_message.get_field(55).unwrap().as_string(),
            decoded_message.get_field(55).unwrap().as_string()
        );
        assert_eq!(
            original_message.get_field(44).unwrap().as_float(),
            decoded_message.get_field(44).unwrap().as_float()
        );
    }

    /// An encoded message decodes with checksum validation enabled.
    #[test]
    fn test_decode_with_checksum_validation() {
        let decoder = GpbDecoder::with_config(DecodeConfig {
            validate_checksums: true,
            ..Default::default()
        });

        let mut encoder = GpbEncoder::new();
        let message =
            FixMessage::new_order_single("ETHUSD".to_string(), 3000.0, 2.0, "2".to_string());

        let encoded = encoder.encode(&message).unwrap();
        let decoded = decoder.decode(encoded).unwrap();

        assert_eq!(message.message_type, decoded.message_type);
    }

    /// Each supported field value kind decodes to the expected variant.
    #[test]
    fn test_decode_different_field_types() {
        let mut encoder = GpbEncoder::new();
        let decoder = GpbDecoder::new();

        let mut message = FixMessage::new(MessageType::Heartbeat);
        message.set_field(1, FieldValue::String("test".to_string()));
        message.set_field(2, FieldValue::Int(-123));
        message.set_field(3, FieldValue::UInt(456));
        message.set_field(4, FieldValue::Float(123.45));
        message.set_field(5, FieldValue::Bool(true));
        message.set_field(
            6,
            FieldValue::Decimal {
                mantissa: 12345,
                scale: 2,
            },
        );

        let encoded = encoder.encode(&message).unwrap();
        let decoded = decoder.decode(encoded).unwrap();

        assert_eq!(message.fields.len(), decoded.fields.len());

        // Verify specific field types
        assert!(matches!(
            decoded.get_field(1).unwrap(),
            FieldValue::String(_)
        ));
        assert!(matches!(decoded.get_field(2).unwrap(), FieldValue::Int(_)));
        assert!(matches!(decoded.get_field(3).unwrap(), FieldValue::UInt(_)));
        assert!(matches!(
            decoded.get_field(4).unwrap(),
            FieldValue::Float(_)
        ));
        assert!(matches!(
            decoded.get_field(6).unwrap(),
            FieldValue::Decimal { .. }
        ));
    }

    /// Malformed input is reported as an error instead of panicking.
    #[test]
    fn test_invalid_data_handling() {
        let decoder = GpbDecoder::new();

        // Test with invalid data
        let invalid_data = b"this is not protobuf data";
        let result = decoder.decode(invalid_data);
        assert!(result.is_err());

        // A complete key/varint pair (field 1, value 150). The varint is not
        // truncated, but field 1 is the message type and must be
        // length-delimited, so decoding is always rejected.
        let wrong_wire_type = b"\x08\x96\x01";
        let result = decoder.decode(wrong_wire_type);
        assert!(result.is_err());
    }

    /// A batch of messages round-trips with count and types preserved.
    #[test]
    fn test_batch_decode() {
        let mut encoder = GpbEncoder::new();
        let decoder = GpbDecoder::new();

        let messages = vec![
            FixMessage::new_order_single("BTC".to_string(), 50000.0, 1.0, "1".to_string()),
            FixMessage::new_order_single("ETH".to_string(), 3000.0, 2.0, "2".to_string()),
        ];

        let encoded = encoder.encode_batch(&messages).unwrap();
        let decoded_messages = decoder.decode_batch(encoded).unwrap();

        assert_eq!(messages.len(), decoded_messages.len());

        for (original, decoded) in messages.iter().zip(decoded_messages.iter()) {
            assert_eq!(original.message_type, decoded.message_type);
        }
    }
}