1use crate::contracts::thread::ToolCall;
12use genai::chat::{ChatStreamEvent, Usage};
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use std::collections::HashMap;
16use tirea_contract::runtime::inference::StopReason;
17use tirea_contract::{StreamResult, TokenUsage};
18
19pub(crate) fn token_usage_from_genai(u: &Usage) -> TokenUsage {
20 let (cache_read, cache_creation) = u
21 .prompt_tokens_details
22 .as_ref()
23 .map_or((None, None), |d| (d.cached_tokens, d.cache_creation_tokens));
24 TokenUsage {
25 prompt_tokens: u.prompt_tokens,
26 completion_tokens: u.completion_tokens,
27 total_tokens: u.total_tokens,
28 cache_read_tokens: cache_read,
29 cache_creation_tokens: cache_creation,
30 }
31}
32
33pub(crate) fn map_genai_stop_reason(reason: &genai::chat::StopReason) -> Option<StopReason> {
34 match reason {
35 genai::chat::StopReason::Completed(_) => Some(StopReason::EndTurn),
36 genai::chat::StopReason::MaxTokens(_) => Some(StopReason::MaxTokens),
37 genai::chat::StopReason::ToolCall(_) => Some(StopReason::ToolUse),
38 genai::chat::StopReason::StopSequence(_) => Some(StopReason::StopSequence),
39 genai::chat::StopReason::ContentFilter(_) | genai::chat::StopReason::Other(_) => None,
40 }
41}
42
43#[derive(Debug, Clone)]
45struct PartialToolCall {
46 id: String,
47 name: String,
48 arguments: String,
49}
50
51#[derive(Debug, Default)]
55pub struct StreamCollector {
56 text: String,
57 tool_calls: HashMap<String, PartialToolCall>,
58 tool_call_order: Vec<String>,
59 usage: Option<Usage>,
60 stop_reason: Option<genai::chat::StopReason>,
61}
62
63#[derive(Debug, Clone, PartialEq, Eq)]
64pub(crate) enum StreamRecoveryCheckpoint {
65 NoPayload,
66 PartialText(String),
67 ToolCallObserved,
68}
69
70impl StreamCollector {
71 pub fn new() -> Self {
73 Self::default()
74 }
75
76 pub fn into_partial_text(self) -> String {
81 self.text
82 }
83
84 pub(crate) fn into_recovery_checkpoint(self) -> StreamRecoveryCheckpoint {
86 if !self.tool_calls.is_empty() {
87 StreamRecoveryCheckpoint::ToolCallObserved
88 } else if self.text.is_empty() {
89 StreamRecoveryCheckpoint::NoPayload
90 } else {
91 StreamRecoveryCheckpoint::PartialText(self.text)
92 }
93 }
94
95 pub fn process(&mut self, event: ChatStreamEvent) -> Option<StreamOutput> {
100 match event {
101 ChatStreamEvent::Chunk(chunk) => {
102 if !chunk.content.is_empty() {
104 self.text.push_str(&chunk.content);
105 return Some(StreamOutput::TextDelta(chunk.content));
106 }
107 None
108 }
109 ChatStreamEvent::ReasoningChunk(chunk) => {
110 if !chunk.content.is_empty() {
111 return Some(StreamOutput::ReasoningDelta(chunk.content));
112 }
113 None
114 }
115 ChatStreamEvent::ThoughtSignatureChunk(chunk) => {
116 if !chunk.content.is_empty() {
117 return Some(StreamOutput::ReasoningEncryptedValue(chunk.content));
118 }
119 None
120 }
121 ChatStreamEvent::ToolCallChunk(tool_chunk) => {
122 let call_id = tool_chunk.tool_call.call_id.clone();
123
124 let partial = match self.tool_calls.entry(call_id.clone()) {
126 std::collections::hash_map::Entry::Occupied(e) => e.into_mut(),
127 std::collections::hash_map::Entry::Vacant(e) => {
128 self.tool_call_order.push(call_id.clone());
129 e.insert(PartialToolCall {
130 id: call_id.clone(),
131 name: String::new(),
132 arguments: String::new(),
133 })
134 }
135 };
136
137 let mut output = None;
138
139 if !tool_chunk.tool_call.fn_name.is_empty() && partial.name.is_empty() {
141 partial.name = tool_chunk.tool_call.fn_name.clone();
142 output = Some(StreamOutput::ToolCallStart {
143 id: call_id.clone(),
144 name: partial.name.clone(),
145 });
146 }
147
148 let args_str = match &tool_chunk.tool_call.fn_arguments {
155 Value::String(s) if !s.is_empty() => s.clone(),
156 Value::Null | Value::String(_) => String::new(),
157 other => other.to_string(),
158 };
159 if !args_str.is_empty() {
160 let delta = if args_str.len() > partial.arguments.len()
162 && args_str.starts_with(&partial.arguments)
163 {
164 args_str[partial.arguments.len()..].to_string()
165 } else {
166 args_str.clone()
167 };
168 partial.arguments = args_str;
169 if !delta.is_empty() && output.is_none() {
171 output = Some(StreamOutput::ToolCallDelta {
172 id: call_id,
173 args_delta: delta,
174 });
175 }
176 }
177
178 output
179 }
180 ChatStreamEvent::End(end) => {
181 self.stop_reason = end.captured_stop_reason.clone();
182 if let Some(tool_calls) = end.captured_tool_calls() {
187 for tc in tool_calls {
188 let end_args = match &tc.fn_arguments {
190 Value::String(s) if !s.is_empty() => s.clone(),
191 Value::Null | Value::String(_) => String::new(),
192 other => other.to_string(),
193 };
194 match self.tool_calls.entry(tc.call_id.clone()) {
195 std::collections::hash_map::Entry::Occupied(mut e) => {
196 let partial = e.get_mut();
197 if partial.name.is_empty() {
198 partial.name = tc.fn_name.clone();
199 }
200 if !end_args.is_empty() {
202 partial.arguments = end_args;
203 }
204 }
205 std::collections::hash_map::Entry::Vacant(e) => {
206 self.tool_call_order.push(tc.call_id.clone());
207 e.insert(PartialToolCall {
208 id: tc.call_id.clone(),
209 name: tc.fn_name.clone(),
210 arguments: end_args,
211 });
212 }
213 }
214 }
215 }
216 self.usage = end.captured_usage;
218 None
219 }
220 _ => None,
221 }
222 }
223
224 pub fn finish(self, max_output_tokens: Option<u32>) -> StreamResult {
231 let mut remaining = self.tool_calls;
232 let mut tool_calls: Vec<ToolCall> = Vec::with_capacity(self.tool_call_order.len());
233
234 for call_id in self.tool_call_order {
235 let Some(p) = remaining.remove(&call_id) else {
236 continue;
237 };
238 if p.name.is_empty() {
239 continue;
240 }
241 let arguments = serde_json::from_str(&p.arguments).unwrap_or(Value::Null);
242 if arguments.is_null() && !p.arguments.is_empty() {
244 continue;
245 }
246 tool_calls.push(ToolCall::new(p.id, p.name, arguments));
247 }
248
249 let usage = self.usage.as_ref().map(token_usage_from_genai);
250 let explicit_stop_reason = self.stop_reason.as_ref().and_then(map_genai_stop_reason);
251 let mut stop_reason = explicit_stop_reason
252 .or_else(|| Self::infer_stop_reason(&tool_calls, &usage, max_output_tokens));
253
254 if matches!(
258 stop_reason,
259 Some(StopReason::MaxTokens) | Some(StopReason::ToolUse)
260 ) {
261 if let (Some(u), Some(max)) = (&usage, max_output_tokens) {
262 if u.completion_tokens == Some(max as i32) {
263 if let Some(last) = tool_calls.last() {
264 if last.arguments.is_null() {
265 tool_calls.pop();
266 stop_reason = explicit_stop_reason.or_else(|| {
269 Self::infer_stop_reason(&tool_calls, &usage, max_output_tokens)
270 });
271 }
272 }
273 }
274 }
275 }
276
277 StreamResult {
278 text: self.text,
279 tool_calls,
280 usage,
281 stop_reason,
282 }
283 }
284
285 fn infer_stop_reason(
292 tool_calls: &[ToolCall],
293 usage: &Option<TokenUsage>,
294 max_output_tokens: Option<u32>,
295 ) -> Option<StopReason> {
296 if !tool_calls.is_empty() {
297 return Some(StopReason::ToolUse);
298 }
299 if let (Some(u), Some(max)) = (usage, max_output_tokens) {
300 if u.completion_tokens == Some(max as i32) {
301 return Some(StopReason::MaxTokens);
302 }
303 }
304 Some(StopReason::EndTurn)
305 }
306
307 pub fn text(&self) -> &str {
309 &self.text
310 }
311
312 pub fn has_tool_calls(&self) -> bool {
314 !self.tool_calls.is_empty()
315 }
316}
317
318#[derive(Debug, Clone, Serialize, Deserialize)]
320#[serde(tag = "type", rename_all = "snake_case")]
321pub enum StreamOutput {
322 TextDelta(String),
324 ReasoningDelta(String),
326 ReasoningEncryptedValue(String),
328 ToolCallStart { id: String, name: String },
330 ToolCallDelta { id: String, args_delta: String },
332}
333
334#[cfg(test)]
335mod tests {
336 use super::*;
337 use crate::contracts::runtime::tool_call::ToolResult;
338 use crate::contracts::AgentEvent;
339 use crate::contracts::TerminationReason;
340 use serde_json::json;
341
342 #[test]
343 fn test_extract_response_with_value() {
344 let result = Some(json!({"response": "Hello world"}));
345 assert_eq!(AgentEvent::extract_response(&result), "Hello world");
346 }
347
348 #[test]
349 fn test_extract_response_none() {
350 assert_eq!(AgentEvent::extract_response(&None), "");
351 }
352
353 #[test]
354 fn test_extract_response_missing_key() {
355 let result = Some(json!({"other": "value"}));
356 assert_eq!(AgentEvent::extract_response(&result), "");
357 }
358
359 #[test]
360 fn test_extract_response_non_string() {
361 let result = Some(json!({"response": 42}));
362 assert_eq!(AgentEvent::extract_response(&result), "");
363 }
364
365 #[test]
366 fn test_stream_collector_new() {
367 let collector = StreamCollector::new();
368 assert!(collector.text().is_empty());
369 assert!(!collector.has_tool_calls());
370 }
371
372 #[test]
373 fn test_map_genai_stop_reason_known_values() {
374 use genai::chat::StopReason as GSR;
375 assert_eq!(
376 map_genai_stop_reason(&GSR::from("stop".to_string())),
377 Some(StopReason::EndTurn)
378 );
379 assert_eq!(
380 map_genai_stop_reason(&GSR::from("end_turn".to_string())),
381 Some(StopReason::EndTurn)
382 );
383 assert_eq!(
384 map_genai_stop_reason(&GSR::from("length".to_string())),
385 Some(StopReason::MaxTokens)
386 );
387 assert_eq!(
388 map_genai_stop_reason(&GSR::from("max_tokens".to_string())),
389 Some(StopReason::MaxTokens)
390 );
391 assert_eq!(
392 map_genai_stop_reason(&GSR::from("tool_calls".to_string())),
393 Some(StopReason::ToolUse)
394 );
395 assert_eq!(
396 map_genai_stop_reason(&GSR::from("stop_sequence".to_string())),
397 Some(StopReason::StopSequence)
398 );
399 }
400
401 #[test]
402 fn test_map_genai_stop_reason_unknown_value() {
403 use genai::chat::StopReason as GSR;
404 assert_eq!(
405 map_genai_stop_reason(&GSR::from("content_filter".to_string())),
406 None
407 );
408 }
409
410 #[test]
411 fn test_stream_collector_finish_prefers_explicit_stop_reason() {
412 let mut collector = StreamCollector::new();
413 collector.process(ChatStreamEvent::End(genai::chat::StreamEnd {
414 captured_usage: Some(Usage {
415 completion_tokens: Some(128),
416 ..Default::default()
417 }),
418 captured_stop_reason: Some(genai::chat::StopReason::from("stop_sequence".to_string())),
419 ..Default::default()
420 }));
421
422 let result = collector.finish(Some(128));
423 assert_eq!(result.stop_reason, Some(StopReason::StopSequence));
424 }
425
426 #[test]
427 fn test_stream_collector_finish_falls_back_when_explicit_stop_reason_unknown() {
428 let mut collector = StreamCollector::new();
429 collector.process(ChatStreamEvent::End(genai::chat::StreamEnd {
430 captured_usage: Some(Usage {
431 completion_tokens: Some(128),
432 ..Default::default()
433 }),
434 captured_stop_reason: Some(genai::chat::StopReason::from(
435 "unknown_stop_reason".to_string(),
436 )),
437 ..Default::default()
438 }));
439
440 let result = collector.finish(Some(128));
441 assert_eq!(result.stop_reason, Some(StopReason::MaxTokens));
442 }
443
444 #[test]
445 fn test_stream_collector_finish_empty() {
446 let collector = StreamCollector::new();
447 let result = collector.finish(None);
448
449 assert!(result.text.is_empty());
450 assert!(result.tool_calls.is_empty());
451 assert!(!result.needs_tools());
452 }
453
454 #[test]
455 fn test_stream_result_needs_tools() {
456 let result = StreamResult {
457 text: "Hello".to_string(),
458 tool_calls: vec![],
459 usage: None,
460 stop_reason: None,
461 };
462 assert!(!result.needs_tools());
463
464 let result_with_tools = StreamResult {
465 text: String::new(),
466 tool_calls: vec![ToolCall::new("id", "name", serde_json::json!({}))],
467 usage: None,
468 stop_reason: None,
469 };
470 assert!(result_with_tools.needs_tools());
471 }
472
473 #[test]
474 fn test_stream_output_variants() {
475 let text_delta = StreamOutput::TextDelta("Hello".to_string());
476 match text_delta {
477 StreamOutput::TextDelta(s) => assert_eq!(s, "Hello"),
478 _ => panic!("Expected TextDelta"),
479 }
480
481 let tool_start = StreamOutput::ToolCallStart {
482 id: "call_1".to_string(),
483 name: "search".to_string(),
484 };
485 match tool_start {
486 StreamOutput::ToolCallStart { id, name } => {
487 assert_eq!(id, "call_1");
488 assert_eq!(name, "search");
489 }
490 _ => panic!("Expected ToolCallStart"),
491 }
492
493 let tool_delta = StreamOutput::ToolCallDelta {
494 id: "call_1".to_string(),
495 args_delta: r#"{"query":"#.to_string(),
496 };
497 match tool_delta {
498 StreamOutput::ToolCallDelta { id, args_delta } => {
499 assert_eq!(id, "call_1");
500 assert!(args_delta.contains("query"));
501 }
502 _ => panic!("Expected ToolCallDelta"),
503 }
504
505 let reasoning_delta = StreamOutput::ReasoningDelta("analysis".to_string());
506 match reasoning_delta {
507 StreamOutput::ReasoningDelta(s) => assert_eq!(s, "analysis"),
508 _ => panic!("Expected ReasoningDelta"),
509 }
510
511 let reasoning_token = StreamOutput::ReasoningEncryptedValue("opaque".to_string());
512 match reasoning_token {
513 StreamOutput::ReasoningEncryptedValue(s) => assert_eq!(s, "opaque"),
514 _ => panic!("Expected ReasoningEncryptedValue"),
515 }
516 }
517
518 #[test]
519 fn test_agent_event_variants() {
520 let event = AgentEvent::TextDelta {
522 delta: "Hello".to_string(),
523 };
524 match event {
525 AgentEvent::TextDelta { delta } => assert_eq!(delta, "Hello"),
526 _ => panic!("Expected TextDelta"),
527 }
528
529 let event = AgentEvent::ReasoningDelta {
530 delta: "thinking".to_string(),
531 };
532 match event {
533 AgentEvent::ReasoningDelta { delta } => assert_eq!(delta, "thinking"),
534 _ => panic!("Expected ReasoningDelta"),
535 }
536
537 let event = AgentEvent::ToolCallStart {
539 id: "call_1".to_string(),
540 name: "search".to_string(),
541 };
542 if let AgentEvent::ToolCallStart { id, name } = event {
543 assert_eq!(id, "call_1");
544 assert_eq!(name, "search");
545 }
546
547 let event = AgentEvent::ToolCallDelta {
549 id: "call_1".to_string(),
550 args_delta: "{}".to_string(),
551 };
552 if let AgentEvent::ToolCallDelta { id, .. } = event {
553 assert_eq!(id, "call_1");
554 }
555
556 let result = ToolResult::success("test", json!({"value": 42}));
558 let event = AgentEvent::ToolCallDone {
559 id: "call_1".to_string(),
560 result: result.clone(),
561 patch: None,
562 message_id: String::new(),
563 };
564 if let AgentEvent::ToolCallDone {
565 id,
566 result: r,
567 patch,
568 ..
569 } = event
570 {
571 assert_eq!(id, "call_1");
572 assert!(r.is_success());
573 assert!(patch.is_none());
574 }
575
576 let event = AgentEvent::RunFinish {
578 thread_id: "t1".to_string(),
579 run_id: "r1".to_string(),
580 result: Some(json!({"response": "Final response"})),
581 termination: crate::contracts::TerminationReason::NaturalEnd,
582 };
583 if let AgentEvent::RunFinish { result, .. } = &event {
584 assert_eq!(AgentEvent::extract_response(result), "Final response");
585 }
586
587 let event = AgentEvent::ActivitySnapshot {
589 message_id: "activity_1".to_string(),
590 activity_type: "progress".to_string(),
591 content: json!({"progress": 0.5}),
592 replace: Some(true),
593 };
594 if let AgentEvent::ActivitySnapshot {
595 message_id,
596 activity_type,
597 content,
598 replace,
599 } = event
600 {
601 assert_eq!(message_id, "activity_1");
602 assert_eq!(activity_type, "progress");
603 assert_eq!(content["progress"], 0.5);
604 assert_eq!(replace, Some(true));
605 }
606
607 let event = AgentEvent::ActivityDelta {
609 message_id: "activity_1".to_string(),
610 activity_type: "progress".to_string(),
611 patch: vec![json!({"op": "replace", "path": "/progress", "value": 0.75})],
612 };
613 if let AgentEvent::ActivityDelta {
614 message_id,
615 activity_type,
616 patch,
617 } = event
618 {
619 assert_eq!(message_id, "activity_1");
620 assert_eq!(activity_type, "progress");
621 assert_eq!(patch.len(), 1);
622 }
623
624 let event = AgentEvent::Error {
626 message: "Something went wrong".to_string(),
627 code: None,
628 };
629 if let AgentEvent::Error { message, .. } = event {
630 assert!(message.contains("wrong"));
631 }
632 }
633
634 #[test]
635 fn test_stream_result_with_multiple_tool_calls() {
636 let result = StreamResult {
637 text: "I'll call multiple tools".to_string(),
638 tool_calls: vec![
639 ToolCall::new("call_1", "search", json!({"q": "rust"})),
640 ToolCall::new("call_2", "calculate", json!({"expr": "1+1"})),
641 ToolCall::new("call_3", "format", json!({"text": "hello"})),
642 ],
643 usage: None,
644 stop_reason: None,
645 };
646
647 assert!(result.needs_tools());
648 assert_eq!(result.tool_calls.len(), 3);
649 assert_eq!(result.tool_calls[0].name, "search");
650 assert_eq!(result.tool_calls[1].name, "calculate");
651 assert_eq!(result.tool_calls[2].name, "format");
652 }
653
654 #[test]
655 fn test_stream_result_text_only() {
656 let result = StreamResult {
657 text: "This is a long response without any tool calls. It just contains text."
658 .to_string(),
659 tool_calls: vec![],
660 usage: None,
661 stop_reason: None,
662 };
663
664 assert!(!result.needs_tools());
665 assert!(result.text.len() > 50);
666 }
667
668 #[test]
669 fn test_tool_call_with_complex_arguments() {
670 let call = ToolCall::new(
671 "call_complex",
672 "api_request",
673 json!({
674 "method": "POST",
675 "url": "https://api.example.com/data",
676 "headers": {
677 "Content-Type": "application/json",
678 "Authorization": "Bearer token"
679 },
680 "body": {
681 "items": [1, 2, 3],
682 "nested": {
683 "deep": true
684 }
685 }
686 }),
687 );
688
689 assert_eq!(call.id, "call_complex");
690 assert_eq!(call.name, "api_request");
691 assert_eq!(call.arguments["method"], "POST");
692 assert!(call.arguments["headers"]["Content-Type"]
693 .as_str()
694 .unwrap()
695 .contains("json"));
696 }
697
698 #[test]
699 fn test_agent_event_done_with_patch() {
700 use tirea_state::{path, Op, Patch, TrackedPatch};
701
702 let patch = TrackedPatch::new(Patch::new().with_op(Op::set(path!("value"), json!(42))));
703
704 let event = AgentEvent::ToolCallDone {
705 id: "call_1".to_string(),
706 result: ToolResult::success("test", json!({})),
707 patch: Some(patch.clone()),
708 message_id: String::new(),
709 };
710
711 if let AgentEvent::ToolCallDone { patch: p, .. } = event {
712 assert!(p.is_some());
713 let p = p.unwrap();
714 assert!(!p.patch().is_empty());
715 }
716 }
717
718 #[test]
719 fn test_stream_output_debug() {
720 let output = StreamOutput::TextDelta("test".to_string());
721 let debug_str = format!("{:?}", output);
722 assert!(debug_str.contains("TextDelta"));
723 assert!(debug_str.contains("test"));
724 }
725
726 #[test]
727 fn test_agent_event_debug() {
728 let event = AgentEvent::Error {
729 message: "error message".to_string(),
730 code: None,
731 };
732 let debug_str = format!("{:?}", event);
733 assert!(debug_str.contains("Error"));
734 assert!(debug_str.contains("error message"));
735 }
736
737 #[test]
738 fn test_stream_result_clone() {
739 let result = StreamResult {
740 text: "Hello".to_string(),
741 tool_calls: vec![ToolCall::new("1", "test", json!({}))],
742 usage: None,
743 stop_reason: None,
744 };
745
746 let cloned = result.clone();
747 assert_eq!(cloned.text, result.text);
748 assert_eq!(cloned.tool_calls.len(), result.tool_calls.len());
749 }
750
751 use genai::chat::{StreamChunk, StreamEnd, ToolChunk};
753
754 #[test]
755 fn test_stream_collector_process_text_chunk() {
756 let mut collector = StreamCollector::new();
757
758 let chunk = ChatStreamEvent::Chunk(StreamChunk {
760 content: "Hello ".to_string(),
761 });
762 let output = collector.process(chunk);
763
764 assert!(output.is_some());
765 if let Some(StreamOutput::TextDelta(delta)) = output {
766 assert_eq!(delta, "Hello ");
767 } else {
768 panic!("Expected TextDelta");
769 }
770
771 assert_eq!(collector.text(), "Hello ");
772 }
773
774 #[test]
775 fn test_stream_collector_process_reasoning_chunk() {
776 let mut collector = StreamCollector::new();
777
778 let chunk = ChatStreamEvent::ReasoningChunk(StreamChunk {
779 content: "chain".to_string(),
780 });
781 let output = collector.process(chunk);
782
783 if let Some(StreamOutput::ReasoningDelta(delta)) = output {
784 assert_eq!(delta, "chain");
785 } else {
786 panic!("Expected ReasoningDelta");
787 }
788 }
789
790 #[test]
791 fn test_stream_collector_process_thought_signature_chunk() {
792 let mut collector = StreamCollector::new();
793
794 let chunk = ChatStreamEvent::ThoughtSignatureChunk(StreamChunk {
795 content: "opaque-token".to_string(),
796 });
797 let output = collector.process(chunk);
798
799 if let Some(StreamOutput::ReasoningEncryptedValue(value)) = output {
800 assert_eq!(value, "opaque-token");
801 } else {
802 panic!("Expected ReasoningEncryptedValue");
803 }
804 }
805
806 #[test]
807 fn test_stream_collector_process_multiple_text_chunks() {
808 let mut collector = StreamCollector::new();
809
810 let chunks = vec!["Hello ", "world", "!"];
812 for text in &chunks {
813 let chunk = ChatStreamEvent::Chunk(StreamChunk {
814 content: text.to_string(),
815 });
816 collector.process(chunk);
817 }
818
819 assert_eq!(collector.text(), "Hello world!");
820
821 let result = collector.finish(None);
822 assert_eq!(result.text, "Hello world!");
823 assert!(!result.needs_tools());
824 }
825
826 #[test]
827 fn test_stream_collector_process_empty_chunk() {
828 let mut collector = StreamCollector::new();
829
830 let chunk = ChatStreamEvent::Chunk(StreamChunk {
831 content: String::new(),
832 });
833 let output = collector.process(chunk);
834
835 assert!(output.is_none());
837 assert!(collector.text().is_empty());
838 }
839
840 #[test]
841 fn test_stream_collector_process_tool_call_start() {
842 let mut collector = StreamCollector::new();
843
844 let tool_call = genai::chat::ToolCall {
845 call_id: "call_123".to_string(),
846 fn_name: "search".to_string(),
847 fn_arguments: json!(null),
848 thought_signatures: None,
849 };
850 let chunk = ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call });
851 let output = collector.process(chunk);
852
853 assert!(output.is_some());
854 if let Some(StreamOutput::ToolCallStart { id, name }) = output {
855 assert_eq!(id, "call_123");
856 assert_eq!(name, "search");
857 } else {
858 panic!("Expected ToolCallStart");
859 }
860
861 assert!(collector.has_tool_calls());
862 }
863
864 #[test]
865 fn test_stream_collector_process_tool_call_with_arguments() {
866 let mut collector = StreamCollector::new();
867
868 let tool_call1 = genai::chat::ToolCall {
870 call_id: "call_abc".to_string(),
871 fn_name: "calculator".to_string(),
872 fn_arguments: json!(null),
873 thought_signatures: None,
874 };
875 collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk {
876 tool_call: tool_call1,
877 }));
878
879 let tool_call2 = genai::chat::ToolCall {
881 call_id: "call_abc".to_string(),
882 fn_name: String::new(), fn_arguments: json!({"expr": "1+1"}),
884 thought_signatures: None,
885 };
886 let output = collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk {
887 tool_call: tool_call2,
888 }));
889
890 assert!(output.is_some());
891 if let Some(StreamOutput::ToolCallDelta { id, args_delta }) = output {
892 assert_eq!(id, "call_abc");
893 assert!(args_delta.contains("expr"));
894 }
895
896 let result = collector.finish(None);
897 assert!(result.needs_tools());
898 assert_eq!(result.tool_calls.len(), 1);
899 assert_eq!(result.tool_calls[0].name, "calculator");
900 }
901
902 #[test]
903 fn test_stream_collector_single_chunk_with_name_and_args_keeps_tool_start() {
904 let mut collector = StreamCollector::new();
905
906 let tool_call = genai::chat::ToolCall {
907 call_id: "call_single".to_string(),
908 fn_name: "search".to_string(),
909 fn_arguments: Value::String(r#"{"q":"rust"}"#.to_string()),
910 thought_signatures: None,
911 };
912 let output = collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call }));
913
914 assert!(
915 matches!(output, Some(StreamOutput::ToolCallStart { .. })),
916 "tool start should not be lost when name+args arrive in one chunk; got: {output:?}"
917 );
918
919 let result = collector.finish(None);
920 assert_eq!(result.tool_calls.len(), 1);
921 assert_eq!(result.tool_calls[0].id, "call_single");
922 assert_eq!(result.tool_calls[0].name, "search");
923 assert_eq!(result.tool_calls[0].arguments, json!({"q":"rust"}));
924 }
925
926 #[test]
927 fn test_stream_collector_preserves_tool_call_arrival_order() {
928 let mut collector = StreamCollector::new();
929 let call_ids = vec![
930 "call_7", "call_3", "call_1", "call_9", "call_2", "call_8", "call_4", "call_6",
931 ];
932
933 for (idx, call_id) in call_ids.iter().enumerate() {
934 let tool_call = genai::chat::ToolCall {
935 call_id: (*call_id).to_string(),
936 fn_name: format!("tool_{idx}"),
937 fn_arguments: Value::Null,
938 thought_signatures: None,
939 };
940 let _ = collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call }));
941 }
942
943 let result = collector.finish(None);
944 let got: Vec<String> = result.tool_calls.into_iter().map(|c| c.id).collect();
945 let expected: Vec<String> = call_ids.into_iter().map(str::to_string).collect();
946
947 assert_eq!(
948 got, expected,
949 "tool_calls should preserve model-emitted order"
950 );
951 }
952
953 #[test]
954 fn test_stream_collector_process_multiple_tool_calls() {
955 let mut collector = StreamCollector::new();
956
957 let tc1 = genai::chat::ToolCall {
959 call_id: "call_1".to_string(),
960 fn_name: "search".to_string(),
961 fn_arguments: json!({"q": "rust"}),
962 thought_signatures: None,
963 };
964 collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: tc1 }));
965
966 let tc2 = genai::chat::ToolCall {
968 call_id: "call_2".to_string(),
969 fn_name: "calculate".to_string(),
970 fn_arguments: json!({"expr": "2+2"}),
971 thought_signatures: None,
972 };
973 collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: tc2 }));
974
975 let result = collector.finish(None);
976 assert_eq!(result.tool_calls.len(), 2);
977 }
978
979 #[test]
980 fn test_stream_collector_process_mixed_text_and_tools() {
981 let mut collector = StreamCollector::new();
982
983 collector.process(ChatStreamEvent::Chunk(StreamChunk {
985 content: "I'll search for that. ".to_string(),
986 }));
987
988 let tc = genai::chat::ToolCall {
990 call_id: "call_search".to_string(),
991 fn_name: "web_search".to_string(),
992 fn_arguments: json!({"query": "rust programming"}),
993 thought_signatures: None,
994 };
995 collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: tc }));
996
997 let result = collector.finish(None);
998 assert_eq!(result.text, "I'll search for that. ");
999 assert_eq!(result.tool_calls.len(), 1);
1000 assert_eq!(result.tool_calls[0].name, "web_search");
1001 }
1002
1003 #[test]
1004 fn test_stream_collector_process_start_event() {
1005 let mut collector = StreamCollector::new();
1006
1007 let output = collector.process(ChatStreamEvent::Start);
1008 assert!(output.is_none());
1009 assert!(collector.text().is_empty());
1010 }
1011
1012 #[test]
1013 fn test_stream_collector_process_end_event() {
1014 let mut collector = StreamCollector::new();
1015
1016 collector.process(ChatStreamEvent::Chunk(StreamChunk {
1018 content: "Hello".to_string(),
1019 }));
1020
1021 let end = StreamEnd::default();
1023 let output = collector.process(ChatStreamEvent::End(end));
1024
1025 assert!(output.is_none());
1026
1027 let result = collector.finish(None);
1028 assert_eq!(result.text, "Hello");
1029 }
1030
1031 #[test]
1032 fn test_stream_collector_has_tool_calls() {
1033 let mut collector = StreamCollector::new();
1034 assert!(!collector.has_tool_calls());
1035
1036 let tc = genai::chat::ToolCall {
1037 call_id: "call_1".to_string(),
1038 fn_name: "test".to_string(),
1039 fn_arguments: json!({}),
1040 thought_signatures: None,
1041 };
1042 collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: tc }));
1043
1044 assert!(collector.has_tool_calls());
1045 }
1046
1047 #[test]
1048 fn test_stream_collector_text_accumulation() {
1049 let mut collector = StreamCollector::new();
1050
1051 let words = vec!["The ", "quick ", "brown ", "fox ", "jumps."];
1053 for word in words {
1054 collector.process(ChatStreamEvent::Chunk(StreamChunk {
1055 content: word.to_string(),
1056 }));
1057 }
1058
1059 assert_eq!(collector.text(), "The quick brown fox jumps.");
1060 }
1061
1062 #[test]
1063 fn test_stream_collector_tool_arguments_accumulation() {
1064 let mut collector = StreamCollector::new();
1067
1068 let tc1 = genai::chat::ToolCall {
1070 call_id: "call_1".to_string(),
1071 fn_name: "api".to_string(),
1072 fn_arguments: json!(null),
1073 thought_signatures: None,
1074 };
1075 collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: tc1 }));
1076
1077 let tc2 = genai::chat::ToolCall {
1079 call_id: "call_1".to_string(),
1080 fn_name: String::new(),
1081 fn_arguments: Value::String("{\"url\":".to_string()),
1082 thought_signatures: None,
1083 };
1084 collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: tc2 }));
1085
1086 let tc3 = genai::chat::ToolCall {
1087 call_id: "call_1".to_string(),
1088 fn_name: String::new(),
1089 fn_arguments: Value::String("{\"url\": \"https://example.com\"}".to_string()),
1090 thought_signatures: None,
1091 };
1092 collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: tc3 }));
1093
1094 let result = collector.finish(None);
1095 assert_eq!(result.tool_calls.len(), 1);
1096 assert_eq!(result.tool_calls[0].name, "api");
1097 assert_eq!(
1098 result.tool_calls[0].arguments,
1099 json!({"url": "https://example.com"})
1100 );
1101 }
1102
1103 #[test]
1104 fn test_stream_collector_value_string_args_accumulation() {
1105 let mut collector = StreamCollector::new();
1108
1109 let tc1 = genai::chat::ToolCall {
1111 call_id: "call_1".to_string(),
1112 fn_name: "get_weather".to_string(),
1113 fn_arguments: Value::String(String::new()),
1114 thought_signatures: None,
1115 };
1116 collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: tc1 }));
1117
1118 let tc2 = genai::chat::ToolCall {
1120 call_id: "call_1".to_string(),
1121 fn_name: String::new(),
1122 fn_arguments: Value::String("{\"city\":".to_string()),
1123 thought_signatures: None,
1124 };
1125 let output2 =
1126 collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: tc2 }));
1127 assert!(matches!(
1128 output2,
1129 Some(StreamOutput::ToolCallDelta { ref args_delta, .. }) if args_delta == "{\"city\":"
1130 ));
1131
1132 let tc3 = genai::chat::ToolCall {
1133 call_id: "call_1".to_string(),
1134 fn_name: String::new(),
1135 fn_arguments: Value::String("{\"city\": \"San Francisco\"}".to_string()),
1136 thought_signatures: None,
1137 };
1138 let output3 =
1139 collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: tc3 }));
1140 assert!(matches!(
1142 output3,
1143 Some(StreamOutput::ToolCallDelta { ref args_delta, .. }) if args_delta == " \"San Francisco\"}"
1144 ));
1145
1146 let result = collector.finish(None);
1147 assert_eq!(result.tool_calls.len(), 1);
1148 assert_eq!(result.tool_calls[0].name, "get_weather");
1149 assert_eq!(
1150 result.tool_calls[0].arguments,
1151 json!({"city": "San Francisco"})
1152 );
1153 }
1154
1155 #[test]
1156 fn test_stream_collector_finish_clears_state() {
1157 let mut collector = StreamCollector::new();
1158
1159 collector.process(ChatStreamEvent::Chunk(StreamChunk {
1160 content: "Test".to_string(),
1161 }));
1162
1163 let result1 = collector.finish(None);
1164 assert_eq!(result1.text, "Test");
1165
1166 }
1169
1170 #[test]
1179 fn test_agent_event_tool_call_ready() {
1180 let event = AgentEvent::ToolCallReady {
1181 id: "call_1".to_string(),
1182 name: "search".to_string(),
1183 arguments: json!({"query": "rust programming"}),
1184 };
1185 if let AgentEvent::ToolCallReady {
1186 id,
1187 name,
1188 arguments,
1189 } = event
1190 {
1191 assert_eq!(id, "call_1");
1192 assert_eq!(name, "search");
1193 assert_eq!(arguments["query"], "rust programming");
1194 } else {
1195 panic!("Expected ToolCallReady");
1196 }
1197 }
1198
/// A `StepStart` event built with an empty message id matches its own variant.
#[test]
fn test_agent_event_step_start() {
    let empty_id = String::new();
    let start = AgentEvent::StepStart {
        message_id: empty_id,
    };
    let is_step_start = matches!(start, AgentEvent::StepStart { .. });
    assert!(is_step_start);
}
1206
/// The field-less `StepEnd` variant matches itself.
#[test]
fn test_agent_event_step_end() {
    let end = AgentEvent::StepEnd;
    let is_step_end = matches!(end, AgentEvent::StepEnd);
    assert!(is_step_end);
}
1212
/// A `RunFinish` event built with `TerminationReason::Cancelled` carries that
/// termination reason back out when destructured.
#[test]
fn test_agent_event_run_finish_cancelled() {
    let finish = AgentEvent::RunFinish {
        thread_id: "t1".to_string(),
        run_id: "r1".to_string(),
        result: None,
        termination: TerminationReason::Cancelled,
    };
    let AgentEvent::RunFinish { termination, .. } = finish else {
        panic!("Expected RunFinish");
    };
    assert_eq!(termination, TerminationReason::Cancelled);
}
1227
/// Serializes representative events and checks the wire shape: a snake_case
/// `type` tag, with the variant's fields carried under `data`.
#[test]
fn test_agent_event_serialization() {
    let event = AgentEvent::TextDelta {
        delta: "Hello".to_string(),
    };
    let json = serde_json::to_string(&event).unwrap();
    assert!(json.contains("\"type\":\"text_delta\""));
    assert!(json.contains("\"data\""));
    // Pin the payload field itself; the tag name is already covered above.
    assert!(json.contains("\"delta\":\"Hello\""));

    let event = AgentEvent::StepStart {
        message_id: String::new(),
    };
    let json = serde_json::to_string(&event).unwrap();
    assert!(json.contains("\"type\":\"step_start\""));

    let event = AgentEvent::ActivitySnapshot {
        message_id: "activity_1".to_string(),
        activity_type: "progress".to_string(),
        content: json!({"progress": 1.0}),
        replace: Some(true),
    };
    let json = serde_json::to_string(&event).unwrap();
    assert!(json.contains("\"type\":\"activity_snapshot\""));
    assert!(json.contains("\"message_id\":\"activity_1\""));
}
1255
/// Parses hand-written JSON for three variants:
/// - a tag with no payload;
/// - a tagged text delta;
/// - an activity snapshot, whose fields are each read back and checked.
#[test]
fn test_agent_event_deserialization() {
    let step: AgentEvent = serde_json::from_str(r#"{"type":"step_start"}"#).unwrap();
    assert!(matches!(step, AgentEvent::StepStart { .. }));

    let text: AgentEvent =
        serde_json::from_str(r#"{"type":"text_delta","data":{"delta":"Hello"}}"#).unwrap();
    let AgentEvent::TextDelta { delta } = text else {
        panic!("Expected TextDelta");
    };
    assert_eq!(delta, "Hello");

    let raw_snapshot = r#"{"type":"activity_snapshot","data":{"message_id":"activity_1","activity_type":"progress","content":{"progress":0.3},"replace":true}}"#;
    let snapshot: AgentEvent = serde_json::from_str(raw_snapshot).unwrap();
    let AgentEvent::ActivitySnapshot {
        message_id,
        activity_type,
        content,
        replace,
    } = snapshot
    else {
        panic!("Expected ActivitySnapshot");
    };
    assert_eq!(message_id, "activity_1");
    assert_eq!(activity_type, "progress");
    assert_eq!(content["progress"], 0.3);
    assert_eq!(replace, Some(true));
}
1287
/// Each `StreamOutput` variant can be constructed and matches its own pattern.
#[test]
fn test_stream_output_variants_creation() {
    let outputs = [
        StreamOutput::TextDelta("Hello".to_string()),
        StreamOutput::ToolCallStart {
            id: "call_1".to_string(),
            name: "search".to_string(),
        },
        StreamOutput::ToolCallDelta {
            id: "call_1".to_string(),
            args_delta: "delta".to_string(),
        },
    ];
    assert!(matches!(outputs[0], StreamOutput::TextDelta(_)));
    assert!(matches!(outputs[1], StreamOutput::ToolCallStart { .. }));
    assert!(matches!(outputs[2], StreamOutput::ToolCallDelta { .. }));
}
1314
/// A freshly created collector reports no text and no tool calls.
#[test]
fn test_stream_collector_text_and_has_tool_calls() {
    let fresh = StreamCollector::new();
    assert_eq!(fresh.text(), "");
    assert!(!fresh.has_tool_calls());
}
1321
/// A nameless chunk with null arguments (a "ghost" call) is not reported by
/// `finish`; the named call streamed after it is.
#[test]
fn test_stream_collector_ghost_tool_call_filtered() {
    let mut collector = StreamCollector::new();

    let chunks = [
        genai::chat::ToolCall {
            call_id: "ghost_1".to_string(),
            fn_name: String::new(),
            fn_arguments: json!(null),
            thought_signatures: None,
        },
        genai::chat::ToolCall {
            call_id: "real_1".to_string(),
            fn_name: "search".to_string(),
            fn_arguments: Value::String(r#"{"q":"rust"}"#.to_string()),
            thought_signatures: None,
        },
    ];
    for tool_call in chunks {
        collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call }));
    }

    let result = collector.finish(None);
    assert_eq!(result.tool_calls.len(), 1);
    assert_eq!(result.tool_calls[0].name, "search");
}
1374
/// A call whose argument string is not parseable JSON is dropped by `finish`.
#[test]
fn test_stream_collector_invalid_json_arguments_dropped() {
    let mut collector = StreamCollector::new();

    let broken = genai::chat::ToolCall {
        call_id: "call_1".to_string(),
        fn_name: "test".to_string(),
        fn_arguments: Value::String("not valid json {{".to_string()),
        thought_signatures: None,
    };
    collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: broken }));

    assert!(collector.finish(None).tool_calls.is_empty());
}
1391
/// When a follow-up chunk for the same call id repeats the full argument
/// string, `process` still yields a `ToolCallDelta`. That delta carries the
/// full repeated string.
#[test]
fn test_stream_collector_duplicate_accumulated_args_full_replace() {
    let mut collector = StreamCollector::new();
    let args = r#"{"a":1}"#;

    let first = genai::chat::ToolCall {
        call_id: "call_1".to_string(),
        fn_name: "test".to_string(),
        fn_arguments: Value::String(args.to_string()),
        thought_signatures: None,
    };
    collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: first }));

    let repeat = genai::chat::ToolCall {
        call_id: "call_1".to_string(),
        fn_name: String::new(),
        fn_arguments: Value::String(args.to_string()),
        thought_signatures: None,
    };
    let emitted =
        collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: repeat }));

    let (id, args_delta) = match emitted {
        Some(StreamOutput::ToolCallDelta { id, args_delta }) => (id, args_delta),
        other => panic!("Expected ToolCallDelta, got {:?}", other),
    };
    assert_eq!(id, "call_1");
    assert_eq!(args_delta, r#"{"a":1}"#);
}
1423
/// Usage attached to the stream's `End` event is surfaced by `finish`, with
/// prompt, completion and total token counts intact.
#[test]
fn test_stream_collector_end_event_captures_usage() {
    let mut collector = StreamCollector::new();

    let end = StreamEnd {
        captured_usage: Some(Usage {
            prompt_tokens: Some(10),
            prompt_tokens_details: None,
            completion_tokens: Some(20),
            completion_tokens_details: None,
            total_tokens: Some(30),
        }),
        ..Default::default()
    };
    collector.process(ChatStreamEvent::End(end));

    let result = collector.finish(None);
    // One step instead of `is_some()` + `unwrap()`: a failure explains itself.
    let usage = result
        .usage
        .expect("usage captured from the End event should be reported by finish");
    assert_eq!(usage.prompt_tokens, Some(10));
    assert_eq!(usage.completion_tokens, Some(20));
    assert_eq!(usage.total_tokens, Some(30));
}
1447
/// A tool call that never arrived as a streamed chunk is still reported by
/// `finish`. Here it appears only in the `End` event's captured content.
#[test]
fn test_stream_collector_end_event_fills_missing_partial() {
    use genai::chat::MessageContent;

    let mut collector = StreamCollector::new();

    let captured = MessageContent::from_tool_calls(vec![genai::chat::ToolCall {
        call_id: "end_call".to_string(),
        fn_name: "finalize".to_string(),
        fn_arguments: Value::String(r#"{"done":true}"#.to_string()),
        thought_signatures: None,
    }]);
    collector.process(ChatStreamEvent::End(StreamEnd {
        captured_content: Some(captured),
        ..Default::default()
    }));

    let result = collector.finish(None);
    assert_eq!(result.tool_calls.len(), 1);
    let call = &result.tool_calls[0];
    assert_eq!(call.id, "end_call");
    assert_eq!(call.name, "finalize");
    assert_eq!(call.arguments, json!({"done": true}));
}
1473
/// Incomplete streamed arguments are replaced by the `End` event's captured
/// arguments for the same call id. The streamed tool name is kept.
#[test]
fn test_stream_collector_end_event_overrides_partial_args() {
    use genai::chat::MessageContent;

    let mut collector = StreamCollector::new();

    let streamed = genai::chat::ToolCall {
        call_id: "call_1".to_string(),
        fn_name: "api".to_string(),
        fn_arguments: Value::String(r#"{"partial":true"#.to_string()),
        thought_signatures: None,
    };
    collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: streamed }));

    let captured = MessageContent::from_tool_calls(vec![genai::chat::ToolCall {
        call_id: "call_1".to_string(),
        fn_name: String::new(),
        fn_arguments: Value::String(r#"{"complete":true}"#.to_string()),
        thought_signatures: None,
    }]);
    collector.process(ChatStreamEvent::End(StreamEnd {
        captured_content: Some(captured),
        ..Default::default()
    }));

    let result = collector.finish(None);
    assert_eq!(result.tool_calls.len(), 1);
    let call = &result.tool_calls[0];
    assert_eq!(call.name, "api");
    assert_eq!(call.arguments, json!({"complete": true}));
}
1509
/// Arguments delivered as a JSON object, rather than a string, produce
/// streamed output. They are reported unchanged by `finish`.
#[test]
fn test_stream_collector_value_object_args() {
    let mut collector = StreamCollector::new();

    let object_args = genai::chat::ToolCall {
        call_id: "call_1".to_string(),
        fn_name: "test".to_string(),
        fn_arguments: json!({"key": "val"}),
        thought_signatures: None,
    };
    let emitted = collector
        .process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: object_args }))
        .is_some();
    assert!(emitted);

    let result = collector.finish(None);
    assert_eq!(result.tool_calls.len(), 1);
    assert_eq!(result.tool_calls[0].arguments, json!({"key": "val"}));
}
1534
/// A call whose argument string is cut off mid-object is dropped by `finish`.
#[test]
fn test_stream_collector_truncated_json_args() {
    let mut collector = StreamCollector::new();

    let truncated = genai::chat::ToolCall {
        call_id: "call_1".to_string(),
        fn_name: "search".to_string(),
        fn_arguments: Value::String(r#"{"url": "https://example.com"#.to_string()),
        thought_signatures: None,
    };
    collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: truncated }));

    assert!(collector.finish(None).tool_calls.is_empty());
}
1560
/// An empty argument string is kept as a call with `Value::Null` arguments.
#[test]
fn test_stream_collector_empty_json_args() {
    let mut collector = StreamCollector::new();

    let no_args = genai::chat::ToolCall {
        call_id: "call_1".to_string(),
        fn_name: "noop".to_string(),
        fn_arguments: Value::String(String::new()),
        thought_signatures: None,
    };
    collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: no_args }));

    let result = collector.finish(None);
    assert_eq!(result.tool_calls.len(), 1);
    let call = &result.tool_calls[0];
    assert_eq!(call.name, "noop");
    assert_eq!(call.arguments, Value::Null);
}
1580
/// Nested JSON arguments truncated inside a string literal are dropped.
#[test]
fn test_stream_collector_partial_nested_json() {
    let mut collector = StreamCollector::new();

    let raw = r#"{"a": {"b": [1, 2, {"c": "long_string_that_gets_truncated"#;
    let truncated = genai::chat::ToolCall {
        call_id: "call_1".to_string(),
        fn_name: "complex_tool".to_string(),
        fn_arguments: Value::String(raw.to_string()),
        thought_signatures: None,
    };
    collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: truncated }));

    assert!(collector.finish(None).tool_calls.is_empty());
}
1600
/// Streamed arguments that were truncated are recovered from the complete
/// arguments carried by the `End` event for the same call id.
#[test]
fn test_stream_collector_truncated_then_end_event_recovers() {
    use genai::chat::MessageContent;

    let mut collector = StreamCollector::new();

    let streamed = genai::chat::ToolCall {
        call_id: "call_1".to_string(),
        fn_name: "api".to_string(),
        fn_arguments: Value::String(r#"{"location": "New York", "unit": "cel"#.to_string()),
        thought_signatures: None,
    };
    collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: streamed }));

    let full_args = r#"{"location": "New York", "unit": "celsius"}"#;
    let captured = MessageContent::from_tool_calls(vec![genai::chat::ToolCall {
        call_id: "call_1".to_string(),
        fn_name: String::new(),
        fn_arguments: Value::String(full_args.to_string()),
        thought_signatures: None,
    }]);
    collector.process(ChatStreamEvent::End(StreamEnd {
        captured_content: Some(captured),
        ..Default::default()
    }));

    let result = collector.finish(None);
    assert_eq!(result.tool_calls.len(), 1);
    assert_eq!(
        result.tool_calls[0].arguments,
        json!({"location": "New York", "unit": "celsius"})
    );
}
1641
/// Control case: well-formed JSON arguments in a single chunk are parsed
/// and reported as-is.
#[test]
fn test_stream_collector_valid_json_args_control() {
    let mut collector = StreamCollector::new();

    let raw = r#"{"location": "San Francisco", "units": "metric"}"#;
    let well_formed = genai::chat::ToolCall {
        call_id: "call_1".to_string(),
        fn_name: "get_weather".to_string(),
        fn_arguments: Value::String(raw.to_string()),
        thought_signatures: None,
    };
    collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: well_formed }));

    let result = collector.finish(None);
    assert_eq!(result.tool_calls.len(), 1);
    assert_eq!(
        result.tool_calls[0].arguments,
        json!({"location": "San Francisco", "units": "metric"})
    );
}
1664
/// An `End` event with no captured content leaves previously streamed tool
/// calls untouched.
#[test]
fn test_stream_collector_end_event_no_tool_calls_preserves_streamed() {
    use genai::chat::StreamEnd;

    let mut collector = StreamCollector::new();

    let streamed = genai::chat::ToolCall {
        call_id: "call_1".to_string(),
        fn_name: "search".to_string(),
        fn_arguments: Value::String(r#"{"q":"test"}"#.to_string()),
        thought_signatures: None,
    };
    collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: streamed }));
    collector.process(ChatStreamEvent::End(StreamEnd {
        captured_content: None,
        ..Default::default()
    }));

    let result = collector.finish(None);
    assert_eq!(
        result.tool_calls.len(),
        1,
        "Streamed tool calls should be preserved"
    );
    let call = &result.tool_calls[0];
    assert_eq!(call.name, "search");
    assert_eq!(call.arguments, json!({"q": "test"}));
}
1706
/// Streams a `search` call, then receives an `End` event that names the same
/// call id `web_search`.
///
/// NOTE(review): despite the test name, the assertion pins that the streamed
/// name (`search`) is what `finish` reports.
#[test]
fn test_stream_collector_end_event_overrides_tool_name() {
    use genai::chat::MessageContent;

    let mut collector = StreamCollector::new();
    let args = r#"{"q":"test"}"#;

    let streamed = genai::chat::ToolCall {
        call_id: "call_1".to_string(),
        fn_name: "search".to_string(),
        fn_arguments: Value::String(args.to_string()),
        thought_signatures: None,
    };
    collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: streamed }));

    let captured = MessageContent::from_tool_calls(vec![genai::chat::ToolCall {
        call_id: "call_1".to_string(),
        fn_name: "web_search".to_string(),
        fn_arguments: Value::String(args.to_string()),
        thought_signatures: None,
    }]);
    collector.process(ChatStreamEvent::End(StreamEnd {
        captured_content: Some(captured),
        ..Default::default()
    }));

    let result = collector.finish(None);
    assert_eq!(result.tool_calls.len(), 1);
    assert_eq!(result.tool_calls[0].name, "search");
}
1742
/// Documents current behavior: a tool name made only of whitespace is not
/// treated as a ghost call, so the call is still reported. This is contrary
/// to what the test name suggests.
#[test]
fn test_stream_collector_whitespace_only_tool_name_filtered() {
    let mut collector = StreamCollector::new();

    let blank_named = genai::chat::ToolCall {
        call_id: "ghost_1".to_string(),
        fn_name: " ".to_string(),
        fn_arguments: Value::String("{}".to_string()),
        thought_signatures: None,
    };
    collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: blank_named }));

    let reported = collector.finish(None).tool_calls.len();
    assert_eq!(
        reported,
        1,
        "Whitespace-only names are currently NOT filtered (document behavior)"
    );
}
1766
1767 fn tc_chunk(call_id: &str, fn_name: &str, args: &str) -> ChatStreamEvent {
1773 ChatStreamEvent::ToolCallChunk(ToolChunk {
1774 tool_call: genai::chat::ToolCall {
1775 call_id: call_id.to_string(),
1776 fn_name: fn_name.to_string(),
1777 fn_arguments: Value::String(args.to_string()),
1778 thought_signatures: None,
1779 },
1780 })
1781 }
1782
/// Two complete tool calls streamed back to back are both reported, each with
/// its own arguments.
#[test]
fn test_stream_collector_two_tool_calls_sequential() {
    let mut collector = StreamCollector::new();

    for (id, name, args) in [
        ("tc_1", "search", r#"{"q":"foo"}"#),
        ("tc_2", "fetch", r#"{"url":"https://x.com"}"#),
    ] {
        collector.process(tc_chunk(id, name, args));
    }

    let result = collector.finish(None);
    assert_eq!(result.tool_calls.len(), 2);

    let names: Vec<&str> = result.tool_calls.iter().map(|tc| tc.name.as_str()).collect();
    assert!(names.contains(&"search"));
    assert!(names.contains(&"fetch"));

    let by_name = |wanted: &str| result.tool_calls.iter().find(|tc| tc.name == wanted).unwrap();
    assert_eq!(by_name("search").arguments, json!({"q": "foo"}));
    assert_eq!(by_name("fetch").arguments, json!({"url": "https://x.com"}));
}
1816
/// Chunks for two tool calls arrive alternately. Each call grows from an empty
/// argument string to a complete object, and both end up with the right
/// final arguments.
#[test]
fn test_stream_collector_two_tool_calls_interleaved_chunks() {
    let mut collector = StreamCollector::new();

    let stream = [
        ("tc_a", "search", ""),
        ("tc_b", "fetch", ""),
        ("tc_a", "search", r#"{"q":"#),
        ("tc_b", "fetch", r#"{"url":"#),
        ("tc_a", "search", r#"{"q":"a"}"#),
        ("tc_b", "fetch", r#"{"url":"b"}"#),
    ];
    for (id, name, args) in stream {
        collector.process(tc_chunk(id, name, args));
    }

    let result = collector.finish(None);
    assert_eq!(result.tool_calls.len(), 2);

    let by_name = |wanted: &str| result.tool_calls.iter().find(|tc| tc.name == wanted).unwrap();
    assert_eq!(by_name("search").arguments, json!({"q": "a"}));
    assert_eq!(by_name("fetch").arguments, json!({"url": "b"}));
}
1857
/// Text chunks interleaved with tool-call chunks concatenate into one text.
/// The single tool call keeps its final arguments.
#[test]
fn test_stream_collector_tool_call_interleaved_with_text() {
    let mut collector = StreamCollector::new();

    let text = |s: &str| {
        ChatStreamEvent::Chunk(StreamChunk {
            content: s.to_string(),
        })
    };
    let events = [
        text("I will "),
        tc_chunk("tc_1", "search", ""),
        text("search "),
        tc_chunk("tc_1", "search", r#"{"q":"test"}"#),
        text("for you."),
    ];
    for event in events {
        collector.process(event);
    }

    let result = collector.finish(None);
    assert_eq!(result.text, "I will search for you.");
    assert_eq!(result.tool_calls.len(), 1);
    assert_eq!(result.tool_calls[0].arguments, json!({"q": "test"}));
}
1882
/// Completion tokens equal the budget passed to `finish`. In that case:
/// - the trailing call with an empty argument string is dropped;
/// - the earlier complete call survives;
/// - the stop reason is `ToolUse`.
#[test]
fn test_last_tool_call_with_null_args_dropped_at_max_tokens() {
    let mut collector = StreamCollector::new();

    let complete = genai::chat::ToolCall {
        call_id: "c1".to_string(),
        fn_name: "search".to_string(),
        fn_arguments: Value::String(r#"{"q":"rust"}"#.to_string()),
        thought_signatures: None,
    };
    let cut_off = genai::chat::ToolCall {
        call_id: "c2".to_string(),
        fn_name: "calcu".to_string(),
        fn_arguments: Value::String(String::new()),
        thought_signatures: None,
    };
    for tool_call in [complete, cut_off] {
        collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call }));
    }

    // Completion tokens equal the limit handed to `finish` below.
    collector.usage = Some(genai::chat::Usage {
        prompt_tokens: Some(100),
        completion_tokens: Some(4096),
        ..Default::default()
    });

    let result = collector.finish(Some(4096));
    assert_eq!(result.tool_calls.len(), 1);
    assert_eq!(result.tool_calls[0].name, "search");
    assert_eq!(result.stop_reason, Some(StopReason::ToolUse));
}
1921
/// Here the only tool call has empty arguments and the token budget is fully
/// used. As a result:
/// - no call is reported;
/// - the stop reason becomes `MaxTokens`;
/// - the streamed text is kept.
#[test]
fn test_single_tool_call_with_null_args_at_max_tokens_triggers_max_tokens() {
    let mut collector = StreamCollector::new();

    collector.process(ChatStreamEvent::Chunk(StreamChunk {
        content: "Let me search".to_string(),
    }));

    let cut_off = genai::chat::ToolCall {
        call_id: "c1".to_string(),
        fn_name: "sear".to_string(),
        fn_arguments: Value::String(String::new()),
        thought_signatures: None,
    };
    collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: cut_off }));

    // Completion tokens equal the limit handed to `finish` below.
    collector.usage = Some(genai::chat::Usage {
        prompt_tokens: Some(100),
        completion_tokens: Some(4096),
        ..Default::default()
    });

    let result = collector.finish(Some(4096));
    assert!(result.tool_calls.is_empty());
    assert_eq!(result.stop_reason, Some(StopReason::MaxTokens));
    assert_eq!(result.text, "Let me search");
}
1953
/// A tool call with complete JSON arguments is kept even when the token
/// budget is fully used. The stop reason stays `ToolUse`.
#[test]
fn test_complete_tool_calls_not_dropped_at_max_tokens() {
    let mut collector = StreamCollector::new();

    let complete = genai::chat::ToolCall {
        call_id: "c1".to_string(),
        fn_name: "search".to_string(),
        fn_arguments: Value::String(r#"{"q":"test"}"#.to_string()),
        thought_signatures: None,
    };
    collector.process(ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call: complete }));

    // Completion tokens equal the limit handed to `finish` below.
    collector.usage = Some(genai::chat::Usage {
        prompt_tokens: Some(100),
        completion_tokens: Some(4096),
        ..Default::default()
    });

    let result = collector.finish(Some(4096));
    assert_eq!(result.tool_calls.len(), 1);
    assert_eq!(result.tool_calls[0].name, "search");
    assert_eq!(result.stop_reason, Some(StopReason::ToolUse));
}
1979
/// `into_partial_text` returns all text chunks concatenated in arrival order.
#[test]
fn test_into_partial_text_returns_accumulated_text() {
    let mut collector = StreamCollector::new();
    for piece in ["Hello ", "world"] {
        collector.process(ChatStreamEvent::Chunk(genai::chat::StreamChunk {
            content: piece.to_string(),
        }));
    }
    let partial = collector.into_partial_text();
    assert_eq!(partial, "Hello world");
}
1991
/// A collector that saw no text yields an empty partial text.
#[test]
fn test_into_partial_text_empty_when_no_text() {
    let partial = StreamCollector::new().into_partial_text();
    assert!(partial.is_empty());
}
1997
/// With only text streamed, the recovery checkpoint carries that partial text.
#[test]
fn test_recovery_checkpoint_uses_partial_text_when_no_tool_call_seen() {
    let mut collector = StreamCollector::new();
    let hello = genai::chat::StreamChunk {
        content: "Hello".to_string(),
    };
    collector.process(ChatStreamEvent::Chunk(hello));

    let checkpoint = collector.into_recovery_checkpoint();
    let expected = StreamRecoveryCheckpoint::PartialText("Hello".to_string());
    assert_eq!(checkpoint, expected);
}
2009
/// Once any tool-call chunk has been seen, the checkpoint is
/// `ToolCallObserved`. This holds even if the chunk's arguments are incomplete.
#[test]
fn test_recovery_checkpoint_marks_tool_call_observed() {
    let mut collector = StreamCollector::new();

    let partial_call = genai::chat::ToolCall {
        call_id: "call_1".to_string(),
        fn_name: "echo".to_string(),
        fn_arguments: Value::String("{\"message\":\"hi".to_string()),
        thought_signatures: None,
    };
    collector.process(ChatStreamEvent::ToolCallChunk(genai::chat::ToolChunk {
        tool_call: partial_call,
    }));

    let checkpoint = collector.into_recovery_checkpoint();
    assert_eq!(checkpoint, StreamRecoveryCheckpoint::ToolCallObserved);
}
2025 }
2026
2027 #[test]
2028 fn test_recovery_checkpoint_marks_no_payload_when_stream_is_empty() {
2029 let collector = StreamCollector::new();
2030 assert_eq!(
2031 collector.into_recovery_checkpoint(),
2032 StreamRecoveryCheckpoint::NoPayload
2033 );
2034 }
2035}