Skip to main content

statsig_rust/specs_response/
proto_stream_reader.rs

1use std::io::Read;
2
3use crate::{
4    networking::{ResponseData, ResponseDataStream},
5    StatsigErr,
6};
7use brotli::Decompressor;
8use bytes::BytesMut;
9
/// Size in bytes of the scratch buffer used per decompressor read, and of the brotli
/// decompressor's internal input buffer (4 KiB).
pub const BUFFER_SIZE: usize = 4 * 1024;
11
12pub struct ProtoStreamReader<'a> {
13    brotli_decompressor: Decompressor<StreamBorrower<'a>>,
14
15    scratch: [u8; BUFFER_SIZE],
16    buf: BytesMut,
17}
18
19impl<'a> ProtoStreamReader<'a> {
20    pub fn new(data: &'a mut ResponseData) -> Self {
21        let stream_borrower = StreamBorrower::new(data);
22        let brotli_decompressor = Decompressor::new(stream_borrower, BUFFER_SIZE);
23
24        Self {
25            brotli_decompressor,
26            scratch: [0u8; BUFFER_SIZE],
27            buf: BytesMut::new(),
28        }
29    }
30
31    pub fn read_next_delimited_proto(&mut self) -> Result<BytesMut, StatsigErr> {
32        let required_len = self.read_length_delimiter()?;
33
34        while self.buf.len() < required_len {
35            match self.brotli_decompressor.read(&mut self.scratch) {
36                Ok(0) => {
37                    return Ok(self.buf.split_to(required_len));
38                }
39                Ok(n) => {
40                    self.buf.extend_from_slice(&self.scratch[..n]);
41                }
42                Err(e) => {
43                    return Err(StatsigErr::ProtobufParseError(
44                        "BrotliDecompressorRead".to_string(),
45                        e.to_string(),
46                    ));
47                }
48            }
49        }
50
51        Ok(self.buf.split_to(required_len))
52    }
53
54    pub fn sample_current_buf(&self) -> String {
55        let len = std::cmp::min(self.buf.len(), 100);
56        let slice = &self.buf.as_ref()[..len];
57        String::from_utf8(slice.to_vec()).unwrap_or_default()
58    }
59
60    fn read_length_delimiter(&mut self) -> Result<usize, StatsigErr> {
61        loop {
62            match prost::decode_length_delimiter(self.buf.as_ref()) {
63                Ok(data_len) => {
64                    return Ok(prost::length_delimiter_len(data_len) + data_len);
65                }
66                Err(e) if self.buf.len() >= 10 => {
67                    return Err(StatsigErr::ProtobufParseError(
68                        "DecodeLengthDelimiter".to_string(),
69                        e.to_string(),
70                    ));
71                }
72                Err(_) => {
73                    let read_len =
74                        self.brotli_decompressor
75                            .read(&mut self.scratch)
76                            .map_err(|e| {
77                                StatsigErr::ProtobufParseError(
78                                    "ReadLengthDelimiter".to_string(),
79                                    e.to_string(),
80                                )
81                            })?;
82
83                    if read_len == 0 {
84                        return Err(StatsigErr::ProtobufParseError(
85                            "ReadLengthDelimiter".to_string(),
86                            "unexpected EOF while reading length delimiter".to_string(),
87                        ));
88                    }
89
90                    self.buf.extend_from_slice(&self.scratch[..read_len]);
91                }
92            }
93        }
94    }
95}
96
97struct StreamBorrower<'a> {
98    stream: &'a mut dyn ResponseDataStream,
99}
100
101impl<'a> StreamBorrower<'a> {
102    pub fn new(data: &'a mut ResponseData) -> Self {
103        Self {
104            stream: data.get_stream_mut(),
105        }
106    }
107}
108
109impl<'a> std::io::Read for StreamBorrower<'a> {
110    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
111        self.stream.read(buf)
112    }
113}