vtcode_core/gemini/streaming/processor.rs

//! Streaming processor for handling real-time responses from the Gemini API
//!
//! This module processes streaming responses from the Gemini API, parsing them in
//! real time and invoking a caller-supplied callback for each chunk of data.
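//!
//! # Example
//!
//! A minimal usage sketch (compiled but not run as a doc test). The request URL is a
//! placeholder, and the import path assumes this module is publicly reachable at its
//! location in the crate:
//!
//! ```no_run
//! use vtcode_core::gemini::streaming::processor::StreamingProcessor;
//!
//! # async fn demo() {
//! // Placeholder endpoint standing in for a streaming Gemini API request.
//! let response = reqwest::get("https://example.com/stream")
//!     .await
//!     .expect("request failed");
//!
//! let mut processor = StreamingProcessor::new();
//! let result = processor
//!     .process_stream(response, |chunk| {
//!         // Print each text chunk as soon as it arrives.
//!         print!("{}", chunk);
//!         Ok(())
//!     })
//!     .await;
//!
//! if result.is_err() {
//!     eprintln!("streaming failed");
//! }
//! # }
//! ```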

use crate::gemini::models::Part;
use crate::gemini::streaming::{
    StreamingCandidate, StreamingError, StreamingMetrics, StreamingResponse,
};
use futures::stream::StreamExt;
use reqwest::Response;
use std::time::Instant;
use tokio::time::{Duration, timeout};

/// Configuration for the streaming processor
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Timeout for reading each chunk
    pub chunk_timeout: Duration,
    /// Maximum time to wait for the first chunk
    pub first_chunk_timeout: Duration,
    /// Buffer size for chunk processing
    pub buffer_size: usize,
}

impl Default for StreamingConfig {
    fn default() -> Self {
        Self {
            chunk_timeout: Duration::from_secs(30),
            first_chunk_timeout: Duration::from_secs(60),
            buffer_size: 1024,
        }
    }
}

/// Streaming processor for handling real-time responses from the Gemini API
pub struct StreamingProcessor {
    config: StreamingConfig,
    metrics: StreamingMetrics,
}

impl StreamingProcessor {
    /// Create a new streaming processor with default configuration
    pub fn new() -> Self {
        Self {
            config: StreamingConfig::default(),
            metrics: StreamingMetrics::default(),
        }
    }

    /// Create a new streaming processor with custom configuration
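    ///
    /// A brief sketch with illustrative values; the import path assumes this module is
    /// publicly reachable at its location in the crate:
    ///
    /// ```no_run
    /// use std::time::Duration;
    /// use vtcode_core::gemini::streaming::processor::{StreamingConfig, StreamingProcessor};
    ///
    /// let config = StreamingConfig {
    ///     chunk_timeout: Duration::from_secs(10),
    ///     first_chunk_timeout: Duration::from_secs(30),
    ///     buffer_size: 512,
    /// };
    /// let processor = StreamingProcessor::with_config(config);
    /// ```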
    pub fn with_config(config: StreamingConfig) -> Self {
        Self {
            config,
            metrics: StreamingMetrics::default(),
        }
    }

    /// Process a streaming response from the Gemini API
    ///
    /// This method takes a response and processes it in real time, calling the provided
    /// callback for each chunk of content received.
    ///
    /// # Arguments
    ///
    /// * `response` - The HTTP response from the Gemini API
    /// * `on_chunk` - A callback function that will be called for each text chunk received
    ///
    /// # Returns
    ///
    /// A result containing the final accumulated response or a streaming error
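    ///
    /// # Examples
    ///
    /// A sketch of collecting the streamed text into a single `String`; the `response`
    /// argument is assumed to come from a prior streaming request to the Gemini API,
    /// and the import path assumes this module is publicly reachable in the crate:
    ///
    /// ```no_run
    /// # use vtcode_core::gemini::streaming::processor::StreamingProcessor;
    /// # async fn demo(response: reqwest::Response) {
    /// let mut processor = StreamingProcessor::new();
    /// let mut collected = String::new();
    ///
    /// let outcome = processor
    ///     .process_stream(response, |chunk| {
    ///         // Accumulate each text chunk as it arrives.
    ///         collected.push_str(chunk);
    ///         Ok(())
    ///     })
    ///     .await;
    ///
    /// if outcome.is_ok() {
    ///     println!("{}", collected);
    /// }
    /// # }
    /// ```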
    pub async fn process_stream<F>(
        &mut self,
        response: Response,
        mut on_chunk: F,
    ) -> Result<StreamingResponse, StreamingError>
    where
        F: FnMut(&str) -> Result<(), StreamingError>,
    {
        self.metrics.request_start_time = Some(Instant::now());
        self.metrics.total_requests += 1;

        // Get the response byte stream
        let mut stream = response.bytes_stream();

        let mut accumulated_response = StreamingResponse {
            candidates: Vec::new(),
            usage_metadata: None,
        };

        let mut has_valid_content = false;
        // Pre-allocate the line buffer using the configured buffer size
        let mut buffer = String::with_capacity(self.config.buffer_size);

        // Wait for the first chunk, allowing a longer timeout than for later chunks
        let first_chunk_result = timeout(self.config.first_chunk_timeout, stream.next()).await;

        match first_chunk_result {
            Ok(Some(Ok(bytes))) => {
                self.metrics.first_chunk_time = Some(Instant::now());
                self.metrics.total_bytes += bytes.len();
                self.metrics.total_chunks += 1;

                // Process the first chunk
                buffer.push_str(&String::from_utf8_lossy(&bytes));
                has_valid_content =
                    self.process_buffer(&mut buffer, &mut accumulated_response, &mut on_chunk)?;
            }
            Ok(Some(Err(e))) => {
                self.metrics.error_count += 1;
                return Err(StreamingError::NetworkError {
                    message: format!("Failed to read first chunk: {}", e),
                    is_retryable: true,
                });
            }
            Ok(None) => {
                return Err(StreamingError::StreamingError {
                    message: "Empty streaming response".to_string(),
                    partial_content: None,
                });
            }
            Err(_) => {
                self.metrics.error_count += 1;
                return Err(StreamingError::TimeoutError {
                    operation: "first_chunk".to_string(),
                    duration: self.config.first_chunk_timeout,
                });
            }
        }

        // Process subsequent chunks, applying the configured per-chunk timeout
        loop {
            let next_item = match timeout(self.config.chunk_timeout, stream.next()).await {
                Ok(item) => item,
                Err(_) => {
                    self.metrics.error_count += 1;
                    return Err(StreamingError::TimeoutError {
                        operation: "chunk".to_string(),
                        duration: self.config.chunk_timeout,
                    });
                }
            };

            // The stream is exhausted once it yields `None`
            let result = match next_item {
                Some(result) => result,
                None => break,
            };

            match result {
                Ok(bytes) => {
                    self.metrics.total_bytes += bytes.len();

                    // Append the chunk to the buffer and process any complete lines
                    buffer.push_str(&String::from_utf8_lossy(&bytes));
                    if self.process_buffer(&mut buffer, &mut accumulated_response, &mut on_chunk)? {
                        has_valid_content = true;
                    }
                }
                Err(e) => {
                    self.metrics.error_count += 1;
                    return Err(StreamingError::NetworkError {
                        message: format!("Failed to read chunk: {}", e),
                        is_retryable: true,
                    });
                }
            }

            self.metrics.total_chunks += 1;
        }

        // Process any data remaining in the buffer after the stream ends
        if !buffer.is_empty() {
            let valid = self.process_remaining_buffer(
                &mut buffer,
                &mut accumulated_response,
                &mut on_chunk,
            )?;
            if valid {
                has_valid_content = true;
            }
        }

        if !has_valid_content {
            return Err(StreamingError::ContentError {
                message: "No valid content received from streaming API".to_string(),
            });
        }

        Ok(accumulated_response)
    }

    /// Process the buffer and extract complete newline-delimited JSON objects
    fn process_buffer<F>(
        &mut self,
        buffer: &mut String,
        accumulated_response: &mut StreamingResponse,
        on_chunk: &mut F,
    ) -> Result<bool, StreamingError>
    where
        F: FnMut(&str) -> Result<(), StreamingError>,
    {
        let mut has_valid_content = false;
        let mut processed_bytes = 0;

        // Process every complete line currently in the buffer
        while let Some(newline_pos) = buffer[processed_bytes..].find('\n') {
            let line_end = processed_bytes + newline_pos;
            let line = buffer[processed_bytes..line_end].trim();
            processed_bytes = line_end + 1; // +1 to skip the newline

            if line.is_empty() {
                continue;
            }

            if self.process_line(line, accumulated_response, on_chunk)? {
                has_valid_content = true;
            }
        }

        // Remove processed content from the buffer, keeping any partial trailing line
        if processed_bytes > 0 {
            *buffer = buffer[processed_bytes..].to_string();
        }

        Ok(has_valid_content)
    }

    /// Process any remaining data in the buffer after streaming is complete
    fn process_remaining_buffer<F>(
        &mut self,
        buffer: &mut String,
        accumulated_response: &mut StreamingResponse,
        on_chunk: &mut F,
    ) -> Result<bool, StreamingError>
    where
        F: FnMut(&str) -> Result<(), StreamingError>,
    {
        let mut has_valid_content = false;

        // Treat the remaining buffer contents as a single line
        let line = buffer.trim();
        if !line.is_empty() && self.process_line(line, accumulated_response, on_chunk)? {
            has_valid_content = true;
        }

        // Clear the buffer
        buffer.clear();

        Ok(has_valid_content)
    }

    /// Process a single line of the streaming response
    fn process_line<F>(
        &mut self,
        line: &str,
        accumulated_response: &mut StreamingResponse,
        on_chunk: &mut F,
    ) -> Result<bool, StreamingError>
    where
        F: FnMut(&str) -> Result<(), StreamingError>,
    {
        let mut has_valid_content = false;

        // Try to parse the line as a JSON response object
        match serde_json::from_str::<StreamingResponse>(line) {
            Ok(response) => {
                // Emit content from the first candidate, if any
                if let Some(candidate) = response.candidates.first() {
                    if self.process_candidate(candidate, on_chunk)? {
                        has_valid_content = true;
                    }
                }

                // Accumulate candidates and usage metadata even when no text was emitted
                accumulated_response.candidates.extend(response.candidates);
                if response.usage_metadata.is_some() {
                    accumulated_response.usage_metadata = response.usage_metadata;
                }
            }
            Err(parse_err) => {
                // The line may be partial or non-JSON content; fall back to a
                // best-effort extraction of quoted text
                if let Some(text) = self.extract_text_from_line(line) {
                    if !text.trim().is_empty() {
                        on_chunk(&text)?;
                        has_valid_content = true;
                    }
                } else {
                    // Log the parsing error but don't fail the stream
                    eprintln!(
                        "Warning: Failed to parse streaming line as JSON: {}",
                        parse_err
                    );
                }
            }
        }

        Ok(has_valid_content)
    }

    /// Process a streaming candidate and extract content
    fn process_candidate<F>(
        &self,
        candidate: &StreamingCandidate,
        on_chunk: &mut F,
    ) -> Result<bool, StreamingError>
    where
        F: FnMut(&str) -> Result<(), StreamingError>,
    {
        let mut has_valid_content = false;

        // Process each part of the candidate's content
        for part in &candidate.content.parts {
            match part {
                Part::Text { text } => {
                    if !text.trim().is_empty() {
                        on_chunk(text)?;
                        has_valid_content = true;
                    }
                }
                Part::FunctionCall { .. } => {
                    // Function calls are handled separately in the tool execution flow
                    has_valid_content = true;
                }
                Part::FunctionResponse { .. } => {
                    has_valid_content = true;
                }
            }
        }

        Ok(has_valid_content)
    }

    /// Extract text content from a line that might not be valid JSON
    fn extract_text_from_line(&self, line: &str) -> Option<String> {
        // Simple extraction of text content between quotes.
        // This is a fallback for cases where the line isn't valid JSON.
        let mut extracted = String::new();
        let mut in_quotes = false;
        let mut escape_next = false;
        let mut current_field = String::new();

        for ch in line.chars() {
            if escape_next {
                current_field.push(ch);
                escape_next = false;
                continue;
            }

            match ch {
                '\\' => {
                    escape_next = true;
                    current_field.push(ch);
                }
                '"' => {
                    if in_quotes {
                        // End of quoted string
                        extracted.push_str(&current_field);
                        current_field.clear();
                        in_quotes = false;
                    } else {
                        // Start of quoted string
                        current_field.clear();
                        in_quotes = true;
                    }
                }
                _ => {
                    if in_quotes {
                        current_field.push(ch);
                    }
                }
            }
        }

        if extracted.is_empty() {
            None
        } else {
            Some(extracted)
        }
    }

    /// Get current streaming metrics
    pub fn metrics(&self) -> &StreamingMetrics {
        &self.metrics
    }

    /// Reset streaming metrics
    pub fn reset_metrics(&mut self) {
        self.metrics = StreamingMetrics::default();
    }
}

impl Default for StreamingProcessor {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_streaming_processor_creation() {
        let processor = StreamingProcessor::new();
        assert_eq!(processor.metrics().total_requests, 0);
    }

    #[test]
    fn test_streaming_processor_with_config() {
        use std::time::Duration;

        let config = StreamingConfig {
            chunk_timeout: Duration::from_secs(10),
            first_chunk_timeout: Duration::from_secs(30),
            buffer_size: 512,
        };

        let processor = StreamingProcessor::with_config(config);
        assert_eq!(processor.metrics().total_requests, 0);
    }

    #[test]
    fn test_streaming_config_default() {
        let config = StreamingConfig::default();
        assert_eq!(config.buffer_size, 1024);
    }
}