1use crate::gemini::models::{Content, Part};
7use crate::gemini::streaming::{
8 StreamingCandidate, StreamingError, StreamingMetrics, StreamingResponse,
9};
10use futures::stream::StreamExt;
11use reqwest::Response;
12use serde_json::Value;
13use std::time::Instant;
14use tokio::time::{Duration, timeout};
15
/// Tunable timing and buffering parameters for consuming a Gemini
/// streaming (SSE) response.
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Maximum time to wait between consecutive chunks after the first.
    pub chunk_timeout: Duration,
    /// Maximum time to wait for the very first chunk; longer than
    /// `chunk_timeout` because it also covers connection setup and model
    /// warm-up latency.
    pub first_chunk_timeout: Duration,
    /// Intended read-buffer size in bytes.
    // NOTE(review): not read anywhere in the visible processing code —
    // confirm whether callers rely on it before removing.
    pub buffer_size: usize,
}
26
impl Default for StreamingConfig {
    /// Conservative defaults: 30s between chunks, 60s for the first chunk,
    /// and a 1 KiB buffer hint.
    fn default() -> Self {
        Self {
            chunk_timeout: Duration::from_secs(30),
            first_chunk_timeout: Duration::from_secs(60),
            buffer_size: 1024,
        }
    }
}
36
/// Stateful processor that consumes an SSE byte stream, reassembles JSON
/// events (which may span multiple lines/chunks), and accumulates them into
/// a single `StreamingResponse`.
pub struct StreamingProcessor {
    // Timeouts and buffer sizing used while draining the stream.
    config: StreamingConfig,
    // Counters/timestamps updated across `process_stream` calls; exposed via
    // `metrics()` and cleared via `reset_metrics()`.
    metrics: StreamingMetrics,
    // Payload of the SSE event currently being assembled; may hold an
    // incomplete JSON document awaiting more data.
    current_event_data: String,
}
43
44impl StreamingProcessor {
45 pub fn new() -> Self {
47 Self {
48 config: StreamingConfig::default(),
49 metrics: StreamingMetrics::default(),
50 current_event_data: String::new(),
51 }
52 }
53
54 pub fn with_config(config: StreamingConfig) -> Self {
56 Self {
57 config,
58 metrics: StreamingMetrics::default(),
59 current_event_data: String::new(),
60 }
61 }
62
63 pub async fn process_stream<F>(
77 &mut self,
78 response: Response,
79 mut on_chunk: F,
80 ) -> Result<StreamingResponse, StreamingError>
81 where
82 F: FnMut(&str) -> Result<(), StreamingError>,
83 {
84 self.metrics.request_start_time = Some(Instant::now());
85 self.metrics.total_requests += 1;
86 self.current_event_data.clear();
87
88 let mut stream = response.bytes_stream();
90
91 let mut accumulated_response = StreamingResponse {
92 candidates: Vec::new(),
93 usage_metadata: None,
94 };
95
96 let mut _has_valid_content = false;
97 let mut buffer = String::new();
98
99 let first_chunk_result = timeout(self.config.first_chunk_timeout, stream.next()).await;
101
102 match first_chunk_result {
103 Ok(Some(Ok(bytes))) => {
104 self.metrics.first_chunk_time = Some(Instant::now());
105 self.metrics.total_bytes += bytes.len();
106
107 buffer.push_str(&String::from_utf8_lossy(&bytes));
109 match self.process_buffer(&mut buffer, &mut accumulated_response, &mut on_chunk) {
110 Ok(valid) => _has_valid_content = valid,
111 Err(e) => return Err(e),
112 }
113 }
114 Ok(Some(Err(e))) => {
115 self.metrics.error_count += 1;
116 return Err(StreamingError::NetworkError {
117 message: format!("Failed to read first chunk: {}", e),
118 is_retryable: true,
119 });
120 }
121 Ok(None) => {
122 return Err(StreamingError::StreamingError {
123 message: "Empty streaming response".to_string(),
124 partial_content: None,
125 });
126 }
127 Err(_) => {
128 self.metrics.error_count += 1;
129 return Err(StreamingError::TimeoutError {
130 operation: "first_chunk".to_string(),
131 duration: self.config.first_chunk_timeout,
132 });
133 }
134 }
135
136 while let Some(result) = stream.next().await {
138 match result {
139 Ok(bytes) => {
140 self.metrics.total_bytes += bytes.len();
141
142 buffer.push_str(&String::from_utf8_lossy(&bytes));
144
145 match self.process_buffer(&mut buffer, &mut accumulated_response, &mut on_chunk)
147 {
148 Ok(valid) => {
149 if valid {
150 _has_valid_content = true;
151 }
152 }
153 Err(e) => return Err(e),
154 }
155 }
156 Err(e) => {
157 self.metrics.error_count += 1;
158 return Err(StreamingError::NetworkError {
159 message: format!("Failed to read chunk: {}", e),
160 is_retryable: true,
161 });
162 }
163 }
164
165 self.metrics.total_chunks += 1;
166 }
167
168 if !buffer.is_empty() {
170 match self.process_remaining_buffer(
171 &mut buffer,
172 &mut accumulated_response,
173 &mut on_chunk,
174 ) {
175 Ok(valid) => {
176 if valid {
177 _has_valid_content = true;
178 }
179 }
180 Err(e) => return Err(e),
181 }
182 }
183
184 if !_has_valid_content {
185 return Err(StreamingError::ContentError {
186 message: "No valid content received from streaming API".to_string(),
187 });
188 }
189
190 Ok(accumulated_response)
191 }
192
193 fn process_buffer<F>(
195 &mut self,
196 buffer: &mut String,
197 accumulated_response: &mut StreamingResponse,
198 on_chunk: &mut F,
199 ) -> Result<bool, StreamingError>
200 where
201 F: FnMut(&str) -> Result<(), StreamingError>,
202 {
203 let mut _has_valid_content = false;
204 let mut processed_chars = 0;
205
206 while let Some(newline_pos) = buffer[processed_chars..].find('\n') {
207 let line_end = processed_chars + newline_pos;
208 let line = &buffer[processed_chars..line_end];
209 processed_chars = line_end + 1;
210
211 match self.handle_line(line, accumulated_response, on_chunk) {
212 Ok(valid) => {
213 if valid {
214 _has_valid_content = true;
215 }
216 }
217 Err(e) => return Err(e),
218 }
219 }
220
221 if processed_chars > 0 {
222 *buffer = buffer[processed_chars..].to_string();
223 }
224
225 Ok(_has_valid_content)
226 }
227
228 fn process_remaining_buffer<F>(
230 &mut self,
231 buffer: &mut String,
232 accumulated_response: &mut StreamingResponse,
233 on_chunk: &mut F,
234 ) -> Result<bool, StreamingError>
235 where
236 F: FnMut(&str) -> Result<(), StreamingError>,
237 {
238 let mut _has_valid_content = false;
239
240 if !buffer.is_empty() {
241 let remaining_line = buffer.trim_end_matches('\r');
242 if !remaining_line.trim().is_empty() {
243 match self.handle_line(remaining_line, accumulated_response, on_chunk) {
244 Ok(valid) => {
245 if valid {
246 _has_valid_content = true;
247 }
248 }
249 Err(e) => return Err(e),
250 }
251 }
252 }
253
254 buffer.clear();
255
256 match self.finalize_current_event(accumulated_response, on_chunk) {
257 Ok(valid) => {
258 if valid {
259 _has_valid_content = true;
260 }
261 }
262 Err(e) => return Err(e),
263 }
264
265 Ok(_has_valid_content)
266 }
267
268 fn handle_line<F>(
270 &mut self,
271 raw_line: &str,
272 accumulated_response: &mut StreamingResponse,
273 on_chunk: &mut F,
274 ) -> Result<bool, StreamingError>
275 where
276 F: FnMut(&str) -> Result<(), StreamingError>,
277 {
278 let mut _has_valid_content = false;
279 let line = raw_line.trim_end_matches('\r');
280
281 if line.is_empty() {
282 match self.finalize_current_event(accumulated_response, on_chunk) {
283 Ok(valid) => {
284 if valid {
285 _has_valid_content = true;
286 }
287 }
288 Err(e) => return Err(e),
289 }
290 return Ok(_has_valid_content);
291 }
292
293 let trimmed = line.trim();
294
295 if trimmed.is_empty() {
296 return Ok(false);
297 }
298
299 if trimmed.starts_with(':') {
300 return Ok(false);
301 }
302
303 if trimmed.starts_with("event:") || trimmed.starts_with("id:") {
304 return Ok(false);
305 }
306
307 if trimmed.starts_with("data:") {
308 let data_segment = trimmed[5..].trim_start();
309 if data_segment == "[DONE]" {
310 match self.finalize_current_event(accumulated_response, on_chunk) {
311 Ok(valid) => {
312 if valid {
313 _has_valid_content = true;
314 }
315 }
316 Err(e) => return Err(e),
317 }
318 return Ok(_has_valid_content);
319 }
320
321 if !data_segment.is_empty() {
322 if !self.current_event_data.is_empty() {
323 self.current_event_data.push('\n');
324 }
325 self.current_event_data.push_str(data_segment);
326
327 match self.try_flush_current_event(accumulated_response, on_chunk) {
328 Ok(valid) => {
329 if valid {
330 _has_valid_content = true;
331 }
332 }
333 Err(e) => return Err(e),
334 }
335 }
336 return Ok(_has_valid_content);
337 }
338
339 if trimmed.starts_with('{') || trimmed.starts_with('[') {
340 if !self.current_event_data.is_empty() {
341 self.current_event_data.push('\n');
342 }
343 self.current_event_data.push_str(trimmed);
344 return Ok(false);
345 }
346
347 if !self.current_event_data.is_empty() {
348 self.current_event_data.push('\n');
349 }
350 self.current_event_data.push_str(trimmed);
351
352 Ok(false)
353 }
354
355 fn finalize_current_event<F>(
356 &mut self,
357 accumulated_response: &mut StreamingResponse,
358 on_chunk: &mut F,
359 ) -> Result<bool, StreamingError>
360 where
361 F: FnMut(&str) -> Result<(), StreamingError>,
362 {
363 if self.current_event_data.trim().is_empty() {
364 self.current_event_data.clear();
365 return Ok(false);
366 }
367
368 let event_data = std::mem::take(&mut self.current_event_data);
369 self.process_event(event_data, accumulated_response, on_chunk)
370 }
371
372 fn try_flush_current_event<F>(
373 &mut self,
374 accumulated_response: &mut StreamingResponse,
375 on_chunk: &mut F,
376 ) -> Result<bool, StreamingError>
377 where
378 F: FnMut(&str) -> Result<(), StreamingError>,
379 {
380 let trimmed = self.current_event_data.trim();
381 if trimmed.is_empty() {
382 return Ok(false);
383 }
384
385 match serde_json::from_str::<Value>(trimmed) {
386 Ok(parsed) => {
387 self.current_event_data.clear();
388 self.process_event_value(parsed, accumulated_response, on_chunk)
389 }
390 Err(parse_err) => {
391 if parse_err.is_eof() {
392 return Ok(false);
393 }
394
395 Err(StreamingError::ParseError {
396 message: format!("Failed to parse streaming JSON: {}", parse_err),
397 raw_response: trimmed.to_string(),
398 })
399 }
400 }
401 }
402
403 fn process_event<F>(
404 &mut self,
405 event_data: String,
406 accumulated_response: &mut StreamingResponse,
407 on_chunk: &mut F,
408 ) -> Result<bool, StreamingError>
409 where
410 F: FnMut(&str) -> Result<(), StreamingError>,
411 {
412 let trimmed = event_data.trim();
413
414 if trimmed.is_empty() {
415 return Ok(false);
416 }
417
418 match serde_json::from_str::<Value>(trimmed) {
419 Ok(parsed) => self.process_event_value(parsed, accumulated_response, on_chunk),
420 Err(parse_err) => {
421 if parse_err.is_eof() {
422 self.current_event_data = trimmed.to_string();
423 return Ok(false);
424 }
425
426 Err(StreamingError::ParseError {
427 message: format!("Failed to parse streaming JSON: {}", parse_err),
428 raw_response: trimmed.to_string(),
429 })
430 }
431 }
432 }
433
434 fn append_text_candidate(&mut self, accumulated_response: &mut StreamingResponse, text: &str) {
435 if text.is_empty() {
436 return;
437 }
438
439 if let Some(last_candidate) = accumulated_response.candidates.last_mut() {
440 Self::merge_parts(
441 &mut last_candidate.content.parts,
442 vec![Part::Text {
443 text: text.to_string(),
444 }],
445 );
446 return;
447 }
448
449 let index = accumulated_response.candidates.len();
450
451 accumulated_response.candidates.push(StreamingCandidate {
452 content: Content {
453 role: "model".to_string(),
454 parts: vec![Part::Text {
455 text: text.to_string(),
456 }],
457 },
458 finish_reason: None,
459 index: Some(index),
460 });
461 }
462
463 fn process_candidate<F>(
465 &self,
466 candidate: &StreamingCandidate,
467 on_chunk: &mut F,
468 ) -> Result<bool, StreamingError>
469 where
470 F: FnMut(&str) -> Result<(), StreamingError>,
471 {
472 let mut _has_valid_content = false;
473
474 for part in &candidate.content.parts {
476 match part {
477 Part::Text { text } => {
478 if !text.trim().is_empty() {
479 on_chunk(text)?;
480 _has_valid_content = true;
481 }
482 }
483 Part::FunctionCall { .. } => {
484 _has_valid_content = true;
486 }
487 Part::FunctionResponse { .. } => {
488 _has_valid_content = true;
489 }
490 }
491 }
492
493 Ok(_has_valid_content)
494 }
495
    /// Interpret one parsed JSON event and fold it into `accumulated_response`,
    /// emitting text through `on_chunk` as it is discovered.
    ///
    /// Handles: arrays (each element recursively), objects (error envelopes,
    /// `usageMetadata`, `candidates`, bare top-level `text`), and bare
    /// strings. All other JSON types yield no content.
    ///
    /// Returns `Ok(true)` when any usable content was produced.
    fn process_event_value<F>(
        &mut self,
        value: Value,
        accumulated_response: &mut StreamingResponse,
        on_chunk: &mut F,
    ) -> Result<bool, StreamingError>
    where
        F: FnMut(&str) -> Result<(), StreamingError>,
    {
        match value {
            Value::Array(items) => {
                let mut has_valid = false;
                for item in items {
                    if self.process_event_value(item, accumulated_response, on_chunk)? {
                        has_valid = true;
                    }
                }
                Ok(has_valid)
            }
            Value::Object(map) => {
                // An "error" envelope aborts processing immediately.
                if let Some(error_value) = map.get("error") {
                    let message = error_value
                        .get("message")
                        .and_then(Value::as_str)
                        .unwrap_or("Gemini streaming error")
                        .to_string();
                    let code = error_value
                        .get("code")
                        .and_then(Value::as_i64)
                        .unwrap_or(500) as u16;
                    return Err(StreamingError::ApiError {
                        status_code: code,
                        message,
                        // Only rate limiting (429) is treated as retryable here.
                        is_retryable: code == 429,
                    });
                }

                // Later usage reports overwrite earlier ones.
                // NOTE(review): assumes the API sends cumulative usage totals —
                // confirm against the Gemini streaming docs.
                if let Some(usage) = map.get("usageMetadata") {
                    accumulated_response.usage_metadata = Some(usage.clone());
                }

                let mut has_valid = false;

                if let Some(candidates_value) = map.get("candidates") {
                    // Accept either an array of candidates or a single object.
                    let candidate_values: Vec<Value> = match candidates_value {
                        Value::Array(items) => items.clone(),
                        Value::Object(_) => vec![candidates_value.clone()],
                        _ => Vec::new(),
                    };

                    for candidate_value in candidate_values {
                        match serde_json::from_value::<StreamingCandidate>(candidate_value.clone())
                        {
                            Ok(candidate) => {
                                // Emit text first, then merge structurally —
                                // merge consumes the candidate.
                                if self.process_candidate(&candidate, on_chunk)? {
                                    has_valid = true;
                                }
                                self.merge_candidate(accumulated_response, candidate);
                            }
                            Err(err) => {
                                // Deserialization failed: fall back to scraping
                                // raw text out of the JSON before giving up.
                                if let Some(text) = Self::extract_text_from_value(&candidate_value)
                                {
                                    if !text.trim().is_empty() {
                                        on_chunk(&text)?;
                                        self.append_text_candidate(accumulated_response, &text);
                                        has_valid = true;
                                    }
                                } else {
                                    return Err(StreamingError::ParseError {
                                        message: format!("Failed to parse candidate: {}", err),
                                        raw_response: candidate_value.to_string(),
                                    });
                                }
                            }
                        }
                    }
                }

                // Some payloads carry a top-level "text" field alongside (or
                // instead of) candidates.
                if let Some(text_value) = map.get("text").and_then(Value::as_str) {
                    if !text_value.trim().is_empty() {
                        on_chunk(text_value)?;
                        self.append_text_candidate(accumulated_response, text_value);
                        has_valid = true;
                    }
                }

                Ok(has_valid)
            }
            Value::String(text) => {
                if text.trim().is_empty() {
                    Ok(false)
                } else {
                    on_chunk(&text)?;
                    self.append_text_candidate(accumulated_response, &text);
                    Ok(true)
                }
            }
            _ => Ok(false),
        }
    }
596
597 fn merge_candidate(
598 &mut self,
599 accumulated_response: &mut StreamingResponse,
600 mut candidate: StreamingCandidate,
601 ) {
602 let index = candidate
603 .index
604 .unwrap_or_else(|| accumulated_response.candidates.len());
605
606 if let Some(existing) = accumulated_response
607 .candidates
608 .iter_mut()
609 .find(|existing| existing.index.unwrap_or(index) == index)
610 {
611 if existing.content.role.is_empty() {
612 existing.content.role = candidate.content.role.clone();
613 }
614
615 Self::merge_parts(&mut existing.content.parts, candidate.content.parts);
616
617 if candidate.finish_reason.is_some() {
618 existing.finish_reason = candidate.finish_reason;
619 }
620 } else {
621 candidate.index = Some(index);
622 accumulated_response.candidates.push(candidate);
623 }
624 }
625
626 fn merge_parts(target: &mut Vec<Part>, source_parts: Vec<Part>) {
627 if target.is_empty() {
628 *target = source_parts;
629 return;
630 }
631
632 for part in source_parts {
633 match (target.last_mut(), &part) {
634 (Some(Part::Text { text: existing }), Part::Text { text: new_text }) => {
635 existing.push_str(new_text);
636 }
637 _ => target.push(part),
638 }
639 }
640 }
641
642 fn extract_text_from_value(value: &Value) -> Option<String> {
643 match value {
644 Value::String(text) => {
645 if text.trim().is_empty() {
646 None
647 } else {
648 Some(text.clone())
649 }
650 }
651 Value::Array(items) => {
652 let mut collected = String::new();
653 for item in items {
654 if let Some(fragment) = Self::extract_text_from_value(item) {
655 collected.push_str(&fragment);
656 }
657 }
658 if collected.is_empty() {
659 None
660 } else {
661 Some(collected)
662 }
663 }
664 Value::Object(map) => {
665 if let Some(text) = map.get("text").and_then(Value::as_str) {
666 if !text.trim().is_empty() {
667 return Some(text.to_string());
668 }
669 }
670
671 if let Some(parts) = map.get("parts").and_then(Value::as_array) {
672 if let Some(parts_text) =
673 Self::extract_text_from_value(&Value::Array(parts.clone()))
674 {
675 return Some(parts_text);
676 }
677 }
678
679 for nested in map.values() {
680 if let Some(nested_text) = Self::extract_text_from_value(nested) {
681 if !nested_text.trim().is_empty() {
682 return Some(nested_text);
683 }
684 }
685 }
686
687 None
688 }
689 _ => None,
690 }
691 }
692
    /// Read-only view of the accumulated streaming metrics.
    pub fn metrics(&self) -> &StreamingMetrics {
        &self.metrics
    }

    /// Reset all metric counters and timestamps to their defaults.
    pub fn reset_metrics(&mut self) {
        self.metrics = StreamingMetrics::default();
    }
702}
703
impl Default for StreamingProcessor {
    /// Equivalent to [`StreamingProcessor::new`].
    fn default() -> Self {
        Self::new()
    }
}
709
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh processor starts with zeroed metrics.
    #[test]
    fn test_streaming_processor_creation() {
        let processor = StreamingProcessor::new();
        assert_eq!(processor.metrics().total_requests, 0);
    }

    // A custom config must not affect the initial (zeroed) metrics.
    #[test]
    fn test_streaming_processor_with_config() {
        use std::time::Duration;

        let config = StreamingConfig {
            chunk_timeout: Duration::from_secs(10),
            first_chunk_timeout: Duration::from_secs(30),
            buffer_size: 512,
        };

        let processor = StreamingProcessor::with_config(config);
        assert_eq!(processor.metrics().total_requests, 0);
    }

    // Default config pins the documented 1 KiB buffer hint.
    #[test]
    fn test_streaming_config_default() {
        let config = StreamingConfig::default();
        assert_eq!(config.buffer_size, 1024);
    }

    // Two consecutive `data:` lines with no blank separator must still be
    // flushed eagerly, and their text coalesced into a single candidate part.
    #[test]
    fn test_handles_back_to_back_data_lines_without_blank_lines() {
        let mut processor = StreamingProcessor::new();
        let mut accumulated = StreamingResponse {
            candidates: Vec::new(),
            usage_metadata: None,
        };
        let mut received_chunks: Vec<String> = Vec::new();
        let mut buffer = String::from(
            "data: {\"candidates\":[{\"index\":0,\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"Hello\"}]}}]}\n",
        );
        buffer.push_str(
            "data: {\"candidates\":[{\"index\":0,\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\" world\"}]}}]}\n",
        );

        // Scope the closure so its borrow of `received_chunks` ends before
        // the assertions below.
        {
            let mut on_chunk = |chunk: &str| {
                received_chunks.push(chunk.to_string());
                Ok(())
            };
            let has_valid = processor
                .process_buffer(&mut buffer, &mut accumulated, &mut on_chunk)
                .expect("processing should succeed");
            assert!(has_valid);
        }

        assert_eq!(received_chunks, vec!["Hello", " world"]);
        assert_eq!(accumulated.candidates.len(), 1);
        let combined = match &accumulated.candidates[0].content.parts[0] {
            Part::Text { text } => text.clone(),
            _ => String::new(),
        };
        assert_eq!(combined, "Hello world");
    }
}