// vtcode_core/gemini/streaming/processor.rs

1//! Streaming processor for handling real-time responses from the Gemini API
2//!
3//! This module provides functionality to process streaming responses from the Gemini API,
4//! parse them in real-time, and provide callbacks for handling each chunk of data.
5
6use crate::gemini::models::{Content, Part};
7use crate::gemini::streaming::{
8    StreamingCandidate, StreamingError, StreamingMetrics, StreamingResponse,
9};
10use futures::stream::StreamExt;
11use reqwest::Response;
12use serde_json::Value;
13use std::time::Instant;
14use tokio::time::{Duration, timeout};
15
/// Configuration for the streaming processor
///
/// Tuning knobs for timeouts and buffer sizing used by [`StreamingProcessor`].
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Timeout for reading each chunk of the response body
    pub chunk_timeout: Duration,
    /// Maximum time to wait for the first chunk
    pub first_chunk_timeout: Duration,
    /// Buffer capacity hint, in bytes, for chunk processing
    pub buffer_size: usize,
}
26
/// Defaults: 30 s per-chunk timeout, 60 s first-chunk timeout, 1 KiB buffer.
impl Default for StreamingConfig {
    fn default() -> Self {
        Self {
            chunk_timeout: Duration::from_secs(30),
            first_chunk_timeout: Duration::from_secs(60),
            buffer_size: 1024,
        }
    }
}
36
/// Streaming processor for handling real-time responses from the Gemini API
pub struct StreamingProcessor {
    // Tuning parameters (timeouts, buffer sizing).
    config: StreamingConfig,
    // Counters and timings accumulated across `process_stream` calls.
    metrics: StreamingMetrics,
    // Payload of the SSE event currently being assembled across `data:` lines.
    current_event_data: String,
}
43
44impl StreamingProcessor {
45    /// Create a new streaming processor with default configuration
46    pub fn new() -> Self {
47        Self {
48            config: StreamingConfig::default(),
49            metrics: StreamingMetrics::default(),
50            current_event_data: String::new(),
51        }
52    }
53
54    /// Create a new streaming processor with custom configuration
55    pub fn with_config(config: StreamingConfig) -> Self {
56        Self {
57            config,
58            metrics: StreamingMetrics::default(),
59            current_event_data: String::new(),
60        }
61    }
62
63    /// Process a streaming response from the Gemini API
64    ///
65    /// This method takes a response and processes it in real-time, calling the provided
66    /// callback for each chunk of content received.
67    ///
68    /// # Arguments
69    ///
70    /// * `response` - The HTTP response from the Gemini API
71    /// * `on_chunk` - A callback function that will be called for each text chunk received
72    ///
73    /// # Returns
74    ///
75    /// A result containing the final accumulated response or a streaming error
76    pub async fn process_stream<F>(
77        &mut self,
78        response: Response,
79        mut on_chunk: F,
80    ) -> Result<StreamingResponse, StreamingError>
81    where
82        F: FnMut(&str) -> Result<(), StreamingError>,
83    {
84        self.metrics.request_start_time = Some(Instant::now());
85        self.metrics.total_requests += 1;
86        self.current_event_data.clear();
87
88        // Get the response stream
89        let mut stream = response.bytes_stream();
90
91        let mut accumulated_response = StreamingResponse {
92            candidates: Vec::new(),
93            usage_metadata: None,
94        };
95
96        let mut _has_valid_content = false;
97        let mut buffer = String::new();
98
99        // Wait for the first chunk with a longer timeout
100        let first_chunk_result = timeout(self.config.first_chunk_timeout, stream.next()).await;
101
102        match first_chunk_result {
103            Ok(Some(Ok(bytes))) => {
104                self.metrics.first_chunk_time = Some(Instant::now());
105                self.metrics.total_bytes += bytes.len();
106
107                // Process the first chunk
108                buffer.push_str(&String::from_utf8_lossy(&bytes));
109                match self.process_buffer(&mut buffer, &mut accumulated_response, &mut on_chunk) {
110                    Ok(valid) => _has_valid_content = valid,
111                    Err(e) => return Err(e),
112                }
113            }
114            Ok(Some(Err(e))) => {
115                self.metrics.error_count += 1;
116                return Err(StreamingError::NetworkError {
117                    message: format!("Failed to read first chunk: {}", e),
118                    is_retryable: true,
119                });
120            }
121            Ok(None) => {
122                return Err(StreamingError::StreamingError {
123                    message: "Empty streaming response".to_string(),
124                    partial_content: None,
125                });
126            }
127            Err(_) => {
128                self.metrics.error_count += 1;
129                return Err(StreamingError::TimeoutError {
130                    operation: "first_chunk".to_string(),
131                    duration: self.config.first_chunk_timeout,
132                });
133            }
134        }
135
136        // Process subsequent chunks
137        while let Some(result) = stream.next().await {
138            match result {
139                Ok(bytes) => {
140                    self.metrics.total_bytes += bytes.len();
141
142                    // Add to buffer
143                    buffer.push_str(&String::from_utf8_lossy(&bytes));
144
145                    // Process buffer
146                    match self.process_buffer(&mut buffer, &mut accumulated_response, &mut on_chunk)
147                    {
148                        Ok(valid) => {
149                            if valid {
150                                _has_valid_content = true;
151                            }
152                        }
153                        Err(e) => return Err(e),
154                    }
155                }
156                Err(e) => {
157                    self.metrics.error_count += 1;
158                    return Err(StreamingError::NetworkError {
159                        message: format!("Failed to read chunk: {}", e),
160                        is_retryable: true,
161                    });
162                }
163            }
164
165            self.metrics.total_chunks += 1;
166        }
167
168        // Process any remaining data in the buffer
169        if !buffer.is_empty() {
170            match self.process_remaining_buffer(
171                &mut buffer,
172                &mut accumulated_response,
173                &mut on_chunk,
174            ) {
175                Ok(valid) => {
176                    if valid {
177                        _has_valid_content = true;
178                    }
179                }
180                Err(e) => return Err(e),
181            }
182        }
183
184        if !_has_valid_content {
185            return Err(StreamingError::ContentError {
186                message: "No valid content received from streaming API".to_string(),
187            });
188        }
189
190        Ok(accumulated_response)
191    }
192
193    /// Process the buffer and extract complete SSE events
194    fn process_buffer<F>(
195        &mut self,
196        buffer: &mut String,
197        accumulated_response: &mut StreamingResponse,
198        on_chunk: &mut F,
199    ) -> Result<bool, StreamingError>
200    where
201        F: FnMut(&str) -> Result<(), StreamingError>,
202    {
203        let mut _has_valid_content = false;
204        let mut processed_chars = 0;
205
206        while let Some(newline_pos) = buffer[processed_chars..].find('\n') {
207            let line_end = processed_chars + newline_pos;
208            let line = &buffer[processed_chars..line_end];
209            processed_chars = line_end + 1;
210
211            match self.handle_line(line, accumulated_response, on_chunk) {
212                Ok(valid) => {
213                    if valid {
214                        _has_valid_content = true;
215                    }
216                }
217                Err(e) => return Err(e),
218            }
219        }
220
221        if processed_chars > 0 {
222            *buffer = buffer[processed_chars..].to_string();
223        }
224
225        Ok(_has_valid_content)
226    }
227
228    /// Process any remaining data in the buffer after streaming is complete
229    fn process_remaining_buffer<F>(
230        &mut self,
231        buffer: &mut String,
232        accumulated_response: &mut StreamingResponse,
233        on_chunk: &mut F,
234    ) -> Result<bool, StreamingError>
235    where
236        F: FnMut(&str) -> Result<(), StreamingError>,
237    {
238        let mut _has_valid_content = false;
239
240        if !buffer.is_empty() {
241            let remaining_line = buffer.trim_end_matches('\r');
242            if !remaining_line.trim().is_empty() {
243                match self.handle_line(remaining_line, accumulated_response, on_chunk) {
244                    Ok(valid) => {
245                        if valid {
246                            _has_valid_content = true;
247                        }
248                    }
249                    Err(e) => return Err(e),
250                }
251            }
252        }
253
254        buffer.clear();
255
256        match self.finalize_current_event(accumulated_response, on_chunk) {
257            Ok(valid) => {
258                if valid {
259                    _has_valid_content = true;
260                }
261            }
262            Err(e) => return Err(e),
263        }
264
265        Ok(_has_valid_content)
266    }
267
268    /// Handle a single SSE line
269    fn handle_line<F>(
270        &mut self,
271        raw_line: &str,
272        accumulated_response: &mut StreamingResponse,
273        on_chunk: &mut F,
274    ) -> Result<bool, StreamingError>
275    where
276        F: FnMut(&str) -> Result<(), StreamingError>,
277    {
278        let mut _has_valid_content = false;
279        let line = raw_line.trim_end_matches('\r');
280
281        if line.is_empty() {
282            match self.finalize_current_event(accumulated_response, on_chunk) {
283                Ok(valid) => {
284                    if valid {
285                        _has_valid_content = true;
286                    }
287                }
288                Err(e) => return Err(e),
289            }
290            return Ok(_has_valid_content);
291        }
292
293        let trimmed = line.trim();
294
295        if trimmed.is_empty() {
296            return Ok(false);
297        }
298
299        if trimmed.starts_with(':') {
300            return Ok(false);
301        }
302
303        if trimmed.starts_with("event:") || trimmed.starts_with("id:") {
304            return Ok(false);
305        }
306
307        if trimmed.starts_with("data:") {
308            let data_segment = trimmed[5..].trim_start();
309            if data_segment == "[DONE]" {
310                match self.finalize_current_event(accumulated_response, on_chunk) {
311                    Ok(valid) => {
312                        if valid {
313                            _has_valid_content = true;
314                        }
315                    }
316                    Err(e) => return Err(e),
317                }
318                return Ok(_has_valid_content);
319            }
320
321            if !data_segment.is_empty() {
322                if !self.current_event_data.is_empty() {
323                    self.current_event_data.push('\n');
324                }
325                self.current_event_data.push_str(data_segment);
326
327                match self.try_flush_current_event(accumulated_response, on_chunk) {
328                    Ok(valid) => {
329                        if valid {
330                            _has_valid_content = true;
331                        }
332                    }
333                    Err(e) => return Err(e),
334                }
335            }
336            return Ok(_has_valid_content);
337        }
338
339        if trimmed.starts_with('{') || trimmed.starts_with('[') {
340            if !self.current_event_data.is_empty() {
341                self.current_event_data.push('\n');
342            }
343            self.current_event_data.push_str(trimmed);
344            return Ok(false);
345        }
346
347        if !self.current_event_data.is_empty() {
348            self.current_event_data.push('\n');
349        }
350        self.current_event_data.push_str(trimmed);
351
352        Ok(false)
353    }
354
355    fn finalize_current_event<F>(
356        &mut self,
357        accumulated_response: &mut StreamingResponse,
358        on_chunk: &mut F,
359    ) -> Result<bool, StreamingError>
360    where
361        F: FnMut(&str) -> Result<(), StreamingError>,
362    {
363        if self.current_event_data.trim().is_empty() {
364            self.current_event_data.clear();
365            return Ok(false);
366        }
367
368        let event_data = std::mem::take(&mut self.current_event_data);
369        self.process_event(event_data, accumulated_response, on_chunk)
370    }
371
372    fn try_flush_current_event<F>(
373        &mut self,
374        accumulated_response: &mut StreamingResponse,
375        on_chunk: &mut F,
376    ) -> Result<bool, StreamingError>
377    where
378        F: FnMut(&str) -> Result<(), StreamingError>,
379    {
380        let trimmed = self.current_event_data.trim();
381        if trimmed.is_empty() {
382            return Ok(false);
383        }
384
385        match serde_json::from_str::<Value>(trimmed) {
386            Ok(parsed) => {
387                self.current_event_data.clear();
388                self.process_event_value(parsed, accumulated_response, on_chunk)
389            }
390            Err(parse_err) => {
391                if parse_err.is_eof() {
392                    return Ok(false);
393                }
394
395                Err(StreamingError::ParseError {
396                    message: format!("Failed to parse streaming JSON: {}", parse_err),
397                    raw_response: trimmed.to_string(),
398                })
399            }
400        }
401    }
402
403    fn process_event<F>(
404        &mut self,
405        event_data: String,
406        accumulated_response: &mut StreamingResponse,
407        on_chunk: &mut F,
408    ) -> Result<bool, StreamingError>
409    where
410        F: FnMut(&str) -> Result<(), StreamingError>,
411    {
412        let trimmed = event_data.trim();
413
414        if trimmed.is_empty() {
415            return Ok(false);
416        }
417
418        match serde_json::from_str::<Value>(trimmed) {
419            Ok(parsed) => self.process_event_value(parsed, accumulated_response, on_chunk),
420            Err(parse_err) => {
421                if parse_err.is_eof() {
422                    self.current_event_data = trimmed.to_string();
423                    return Ok(false);
424                }
425
426                Err(StreamingError::ParseError {
427                    message: format!("Failed to parse streaming JSON: {}", parse_err),
428                    raw_response: trimmed.to_string(),
429                })
430            }
431        }
432    }
433
434    fn append_text_candidate(&mut self, accumulated_response: &mut StreamingResponse, text: &str) {
435        if text.is_empty() {
436            return;
437        }
438
439        if let Some(last_candidate) = accumulated_response.candidates.last_mut() {
440            Self::merge_parts(
441                &mut last_candidate.content.parts,
442                vec![Part::Text {
443                    text: text.to_string(),
444                }],
445            );
446            return;
447        }
448
449        let index = accumulated_response.candidates.len();
450
451        accumulated_response.candidates.push(StreamingCandidate {
452            content: Content {
453                role: "model".to_string(),
454                parts: vec![Part::Text {
455                    text: text.to_string(),
456                }],
457            },
458            finish_reason: None,
459            index: Some(index),
460        });
461    }
462
463    /// Process a streaming candidate and extract content
464    fn process_candidate<F>(
465        &self,
466        candidate: &StreamingCandidate,
467        on_chunk: &mut F,
468    ) -> Result<bool, StreamingError>
469    where
470        F: FnMut(&str) -> Result<(), StreamingError>,
471    {
472        let mut _has_valid_content = false;
473
474        // Process each part of the content
475        for part in &candidate.content.parts {
476            match part {
477                Part::Text { text } => {
478                    if !text.trim().is_empty() {
479                        on_chunk(text)?;
480                        _has_valid_content = true;
481                    }
482                }
483                Part::FunctionCall { .. } => {
484                    // Function calls are handled separately in the tool execution flow
485                    _has_valid_content = true;
486                }
487                Part::FunctionResponse { .. } => {
488                    _has_valid_content = true;
489                }
490            }
491        }
492
493        Ok(_has_valid_content)
494    }
495
    /// Interpret one parsed JSON event value and fold its content into the
    /// accumulated response.
    ///
    /// Handles three shapes: an array of event values (processed
    /// recursively), an object carrying `error` / `usageMetadata` /
    /// `candidates` / top-level `text` fields, and a bare string chunk.
    ///
    /// Returns `Ok(true)` when at least one piece of usable content was
    /// forwarded to `on_chunk`; returns `Err` for API-reported errors and
    /// for candidates that neither deserialize nor contain recoverable text.
    fn process_event_value<F>(
        &mut self,
        value: Value,
        accumulated_response: &mut StreamingResponse,
        on_chunk: &mut F,
    ) -> Result<bool, StreamingError>
    where
        F: FnMut(&str) -> Result<(), StreamingError>,
    {
        match value {
            Value::Array(items) => {
                let mut has_valid = false;
                for item in items {
                    if self.process_event_value(item, accumulated_response, on_chunk)? {
                        has_valid = true;
                    }
                }
                Ok(has_valid)
            }
            Value::Object(map) => {
                // An `error` object aborts processing of this event entirely.
                if let Some(error_value) = map.get("error") {
                    let message = error_value
                        .get("message")
                        .and_then(Value::as_str)
                        .unwrap_or("Gemini streaming error")
                        .to_string();
                    // NOTE(review): `as u16` truncates codes outside 0..=65535,
                    // and only 429 is treated as retryable (not 500/503) —
                    // confirm this matches the intended retry policy.
                    let code = error_value
                        .get("code")
                        .and_then(Value::as_i64)
                        .unwrap_or(500) as u16;
                    return Err(StreamingError::ApiError {
                        status_code: code,
                        message,
                        is_retryable: code == 429,
                    });
                }

                // Later usage metadata replaces earlier snapshots wholesale.
                if let Some(usage) = map.get("usageMetadata") {
                    accumulated_response.usage_metadata = Some(usage.clone());
                }

                let mut has_valid = false;

                if let Some(candidates_value) = map.get("candidates") {
                    // Accept both an array of candidates and a lone object.
                    let candidate_values: Vec<Value> = match candidates_value {
                        Value::Array(items) => items.clone(),
                        Value::Object(_) => vec![candidates_value.clone()],
                        _ => Vec::new(),
                    };

                    for candidate_value in candidate_values {
                        match serde_json::from_value::<StreamingCandidate>(candidate_value.clone())
                        {
                            Ok(candidate) => {
                                if self.process_candidate(&candidate, on_chunk)? {
                                    has_valid = true;
                                }
                                self.merge_candidate(accumulated_response, candidate);
                            }
                            Err(err) => {
                                // Deserialization failed: salvage any text we
                                // can find before giving up with a ParseError.
                                if let Some(text) = Self::extract_text_from_value(&candidate_value)
                                {
                                    if !text.trim().is_empty() {
                                        on_chunk(&text)?;
                                        self.append_text_candidate(accumulated_response, &text);
                                        has_valid = true;
                                    }
                                } else {
                                    return Err(StreamingError::ParseError {
                                        message: format!("Failed to parse candidate: {}", err),
                                        raw_response: candidate_value.to_string(),
                                    });
                                }
                            }
                        }
                    }
                }

                // Some payloads carry a bare top-level "text" field.
                if let Some(text_value) = map.get("text").and_then(Value::as_str) {
                    if !text_value.trim().is_empty() {
                        on_chunk(text_value)?;
                        self.append_text_candidate(accumulated_response, text_value);
                        has_valid = true;
                    }
                }

                Ok(has_valid)
            }
            Value::String(text) => {
                if text.trim().is_empty() {
                    Ok(false)
                } else {
                    on_chunk(&text)?;
                    self.append_text_candidate(accumulated_response, &text);
                    Ok(true)
                }
            }
            // Numbers, booleans, and null carry no streamable content.
            _ => Ok(false),
        }
    }
596
597    fn merge_candidate(
598        &mut self,
599        accumulated_response: &mut StreamingResponse,
600        mut candidate: StreamingCandidate,
601    ) {
602        let index = candidate
603            .index
604            .unwrap_or_else(|| accumulated_response.candidates.len());
605
606        if let Some(existing) = accumulated_response
607            .candidates
608            .iter_mut()
609            .find(|existing| existing.index.unwrap_or(index) == index)
610        {
611            if existing.content.role.is_empty() {
612                existing.content.role = candidate.content.role.clone();
613            }
614
615            Self::merge_parts(&mut existing.content.parts, candidate.content.parts);
616
617            if candidate.finish_reason.is_some() {
618                existing.finish_reason = candidate.finish_reason;
619            }
620        } else {
621            candidate.index = Some(index);
622            accumulated_response.candidates.push(candidate);
623        }
624    }
625
626    fn merge_parts(target: &mut Vec<Part>, source_parts: Vec<Part>) {
627        if target.is_empty() {
628            *target = source_parts;
629            return;
630        }
631
632        for part in source_parts {
633            match (target.last_mut(), &part) {
634                (Some(Part::Text { text: existing }), Part::Text { text: new_text }) => {
635                    existing.push_str(new_text);
636                }
637                _ => target.push(part),
638            }
639        }
640    }
641
642    fn extract_text_from_value(value: &Value) -> Option<String> {
643        match value {
644            Value::String(text) => {
645                if text.trim().is_empty() {
646                    None
647                } else {
648                    Some(text.clone())
649                }
650            }
651            Value::Array(items) => {
652                let mut collected = String::new();
653                for item in items {
654                    if let Some(fragment) = Self::extract_text_from_value(item) {
655                        collected.push_str(&fragment);
656                    }
657                }
658                if collected.is_empty() {
659                    None
660                } else {
661                    Some(collected)
662                }
663            }
664            Value::Object(map) => {
665                if let Some(text) = map.get("text").and_then(Value::as_str) {
666                    if !text.trim().is_empty() {
667                        return Some(text.to_string());
668                    }
669                }
670
671                if let Some(parts) = map.get("parts").and_then(Value::as_array) {
672                    if let Some(parts_text) =
673                        Self::extract_text_from_value(&Value::Array(parts.clone()))
674                    {
675                        return Some(parts_text);
676                    }
677                }
678
679                for nested in map.values() {
680                    if let Some(nested_text) = Self::extract_text_from_value(nested) {
681                        if !nested_text.trim().is_empty() {
682                            return Some(nested_text);
683                        }
684                    }
685                }
686
687                None
688            }
689            _ => None,
690        }
691    }
692
    /// Get current streaming metrics
    ///
    /// Metrics accumulate across `process_stream` calls; use
    /// [`Self::reset_metrics`] to zero them.
    pub fn metrics(&self) -> &StreamingMetrics {
        &self.metrics
    }
697
    /// Reset streaming metrics
    ///
    /// Replaces all counters and timestamps with their default (zeroed) state.
    pub fn reset_metrics(&mut self) {
        self.metrics = StreamingMetrics::default();
    }
702}
703
impl Default for StreamingProcessor {
    /// Equivalent to [`StreamingProcessor::new`].
    fn default() -> Self {
        Self::new()
    }
}
709
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh processor starts with zeroed metrics.
    #[test]
    fn test_streaming_processor_creation() {
        let processor = StreamingProcessor::new();
        assert_eq!(processor.metrics().total_requests, 0);
    }

    // Custom configuration does not affect the initial metrics state.
    #[test]
    fn test_streaming_processor_with_config() {
        use std::time::Duration;

        let config = StreamingConfig {
            chunk_timeout: Duration::from_secs(10),
            first_chunk_timeout: Duration::from_secs(30),
            buffer_size: 512,
        };

        let processor = StreamingProcessor::with_config(config);
        assert_eq!(processor.metrics().total_requests, 0);
    }

    // Default buffer size is 1 KiB.
    #[test]
    fn test_streaming_config_default() {
        let config = StreamingConfig::default();
        assert_eq!(config.buffer_size, 1024);
    }

    // Consecutive `data:` lines with no blank-line separators must still be
    // parsed eagerly, streamed in order, and coalesced into one candidate.
    #[test]
    fn test_handles_back_to_back_data_lines_without_blank_lines() {
        let mut processor = StreamingProcessor::new();
        let mut accumulated = StreamingResponse {
            candidates: Vec::new(),
            usage_metadata: None,
        };
        let mut received_chunks: Vec<String> = Vec::new();
        let mut buffer = String::from(
            "data: {\"candidates\":[{\"index\":0,\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"Hello\"}]}}]}\n",
        );
        buffer.push_str(
            "data: {\"candidates\":[{\"index\":0,\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\" world\"}]}}]}\n",
        );

        {
            let mut on_chunk = |chunk: &str| {
                received_chunks.push(chunk.to_string());
                Ok(())
            };
            let has_valid = processor
                .process_buffer(&mut buffer, &mut accumulated, &mut on_chunk)
                .expect("processing should succeed");
            assert!(has_valid);
        }

        assert_eq!(received_chunks, vec!["Hello", " world"]);
        assert_eq!(accumulated.candidates.len(), 1);
        // Adjacent text parts on the same candidate must be merged.
        let combined = match &accumulated.candidates[0].content.parts[0] {
            Part::Text { text } => text.clone(),
            _ => String::new(),
        };
        assert_eq!(combined, "Hello world");
    }
}