1use futures::Stream;
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::borrow::Cow;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio_util::sync::CancellationToken;
16
17pub use crate::stream_error_kind::StreamErrorKind;
18use crate::types::{
19 AgentContext, AssistantMessage, ContentBlock, Cost, ModelSpec, StopReason, Usage,
20};
21
/// Wire transport used to receive a streamed model response.
///
/// Serialized with `snake_case` variant names, so [`StreamTransport::Sse`]
/// becomes `"sse"`. `Sse` is currently the only variant and is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StreamTransport {
    /// Server-Sent Events (the default transport).
    #[default]
    Sse,
}
32
/// Prompt-caching strategy requested for a stream.
///
/// Defaults to [`CacheStrategy::None`]. How each strategy is realised is up
/// to the provider adapter consuming [`StreamOptions`] — NOTE(review): the
/// per-variant semantics below are inferred from the names; confirm in the adapters.
#[derive(Clone, Debug, Default)]
pub enum CacheStrategy {
    /// No caching strategy requested (the default).
    #[default]
    None,
    /// Let the adapter pick a strategy automatically.
    Auto,
    /// Anthropic-specific caching.
    Anthropic,
    /// Google-specific caching with an explicit time-to-live.
    Google {
        /// Requested lifetime of the cached entry.
        ttl: Duration,
    },
}
57
/// Callback invoked with raw payload text, borrowed for the duration of the call.
///
/// Held behind an `Arc`, so cloning [`StreamOptions`] shares the same callback
/// rather than duplicating it. The `Send + Sync` bounds let it be invoked from
/// whichever task drives the stream (NOTE(review): invocation site not visible here).
pub type OnRawPayload = Arc<dyn for<'payload> Fn(&'payload str) + Send + Sync>;
66
/// Per-request options handed to [`StreamFn::stream`].
///
/// Every field has a `Default`, so callers can write
/// `StreamOptions { temperature: Some(0.2), ..Default::default() }`.
/// `Debug` is implemented by hand so that the API key is redacted and the
/// raw-payload callback (a closure without a `Debug` impl) is shown as a placeholder.
#[derive(Clone, Default)]
pub struct StreamOptions {
    /// Sampling temperature, if set (presumably `None` defers to the provider default — confirm).
    pub temperature: Option<f64>,
    /// Token limit for the response, if set (NOTE(review): presumably output tokens — confirm).
    pub max_tokens: Option<u64>,
    /// Optional session identifier (how it is used is provider-specific — confirm).
    pub session_id: Option<String>,
    /// API key for this request, if supplied; redacted by the `Debug` impl.
    pub api_key: Option<String>,
    /// Transport used to receive the stream.
    pub transport: StreamTransport,
    /// Prompt-caching strategy to request.
    pub cache_strategy: CacheStrategy,
    /// Optional callback that receives raw payload text.
    pub on_raw_payload: Option<OnRawPayload>,
}
87
88impl std::fmt::Debug for StreamOptions {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 f.debug_struct("StreamOptions")
91 .field("temperature", &self.temperature)
92 .field("max_tokens", &self.max_tokens)
93 .field("session_id", &self.session_id)
94 .field("api_key", &self.api_key.as_ref().map(|_| "[REDACTED]"))
95 .field("transport", &self.transport)
96 .field("cache_strategy", &self.cache_strategy)
97 .field(
98 "on_raw_payload",
99 &self.on_raw_payload.as_ref().map(|_| "<callback>"),
100 )
101 .finish()
102 }
103}
104
/// One event in a streamed assistant response.
///
/// A well-formed stream, as enforced by [`accumulate_message`], is:
/// `Start`, then any number of content blocks (each opened by a `*Start`,
/// extended by `*Delta`s and closed by the matching `*End`), then exactly one
/// terminal event, `Done` or `Error`. Content blocks are addressed by
/// `content_index`, their position in the final message content.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum AssistantMessageEvent {
    /// Opens the message; must come first and appear exactly once.
    Start,

    /// Opens a text block; `content_index` must equal the number of blocks opened so far.
    TextStart { content_index: usize },
    /// Appends `delta` to the open text block at `content_index`.
    TextDelta { content_index: usize, delta: String },
    /// Closes the text block at `content_index`.
    TextEnd { content_index: usize },

    /// Opens a thinking block; `content_index` must equal the number of blocks opened so far.
    ThinkingStart { content_index: usize },
    /// Appends `delta` to the open thinking block at `content_index`.
    ThinkingDelta { content_index: usize, delta: String },
    /// Closes the thinking block at `content_index`.
    ThinkingEnd {
        /// Index of the thinking block being closed.
        content_index: usize,
        /// Optional signature; stored on the thinking block when it is closed.
        signature: Option<String>,
    },

    /// Opens a tool-call block for tool `name` with call id `id`.
    ToolCallStart {
        /// Must equal the number of blocks opened so far.
        content_index: usize,
        id: String,
        name: String,
    },
    /// Appends a fragment of the tool call's JSON-encoded arguments.
    ToolCallDelta { content_index: usize, delta: String },
    /// Closes the tool call; the concatenated fragments are parsed as JSON
    /// (an empty buffer yields an empty object).
    ToolCallEnd { content_index: usize },

    /// Successful terminal event.
    Done {
        /// Why generation stopped (e.g. `StopReason::Length` on truncation).
        stop_reason: StopReason,
        usage: Usage,
        cost: Cost,
    },

    /// Failed terminal event; unterminated blocks are accepted before it.
    Error {
        /// Stop reason reported with the failure.
        stop_reason: StopReason,
        error_message: String,
        /// Usage reported with the failure, if any.
        usage: Option<Usage>,
        /// Optional machine-readable classification of the failure.
        error_kind: Option<StreamErrorKind>,
    },
}
166
167impl AssistantMessageEvent {
168 pub fn error(message: impl Into<String>) -> Self {
174 Self::Error {
175 stop_reason: StopReason::Error,
176 error_message: message.into(),
177 usage: None,
178 error_kind: None,
179 }
180 }
181
182 pub fn error_throttled(message: impl Into<String>) -> Self {
187 Self::Error {
188 stop_reason: StopReason::Error,
189 error_message: message.into(),
190 usage: None,
191 error_kind: Some(StreamErrorKind::Throttled),
192 }
193 }
194
195 pub fn error_context_overflow(message: impl Into<String>) -> Self {
200 Self::Error {
201 stop_reason: StopReason::Error,
202 error_message: message.into(),
203 usage: None,
204 error_kind: Some(StreamErrorKind::ContextWindowExceeded),
205 }
206 }
207
208 pub fn error_auth(message: impl Into<String>) -> Self {
213 Self::Error {
214 stop_reason: StopReason::Error,
215 error_message: message.into(),
216 usage: None,
217 error_kind: Some(StreamErrorKind::Auth),
218 }
219 }
220
221 pub fn error_network(message: impl Into<String>) -> Self {
226 Self::Error {
227 stop_reason: StopReason::Error,
228 error_message: message.into(),
229 usage: None,
230 error_kind: Some(StreamErrorKind::Network),
231 }
232 }
233
234 pub fn error_content_filtered(message: impl Into<String>) -> Self {
239 Self::Error {
240 stop_reason: StopReason::Error,
241 error_message: message.into(),
242 usage: None,
243 error_kind: Some(StreamErrorKind::ContentFiltered),
244 }
245 }
246
247 pub fn text_response(text: &str) -> Vec<Self> {
253 vec![
254 Self::Start,
255 Self::TextStart { content_index: 0 },
256 Self::TextDelta {
257 content_index: 0,
258 delta: text.to_string(),
259 },
260 Self::TextEnd { content_index: 0 },
261 Self::Done {
262 stop_reason: StopReason::Stop,
263 usage: Usage::default(),
264 cost: Cost::default(),
265 },
266 ]
267 }
268}
269
/// A single incremental content fragment addressed to one block.
///
/// Serialized as an internally tagged object whose `"type"` field is
/// `"text"`, `"thinking"`, or `"tool_call"`, for example
/// `{"type":"text","content_index":0,"delta":"Hi"}`.
/// `Cow<'static, str>` lets callers pass either a static string or an owned one.
/// NOTE(review): producers/consumers of this type are not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AssistantMessageDelta {
    /// Fragment of a text block.
    Text {
        /// Index of the target block in the message content.
        content_index: usize,
        delta: Cow<'static, str>,
    },
    /// Fragment of a thinking block.
    Thinking {
        /// Index of the target block in the message content.
        content_index: usize,
        delta: Cow<'static, str>,
    },
    /// Fragment of a tool call's JSON-encoded arguments.
    ToolCall {
        /// Index of the target block in the message content.
        content_index: usize,
        delta: Cow<'static, str>,
    },
}
295
/// A streaming model backend.
///
/// Implementors turn a model spec, conversation context, and per-request
/// options into a stream of [`AssistantMessageEvent`]s. The trait requires
/// `Send + Sync`, so an implementation can be shared across threads.
pub trait StreamFn: Send + Sync {
    /// Starts a streaming request.
    ///
    /// * `model`, `context`, `options` — borrowed for `'a`, the lifetime of the
    ///   returned stream.
    /// * `cancellation_token` — passed by value to the implementation;
    ///   presumably the stream should stop once it is cancelled (TODO confirm the
    ///   expected behaviour).
    ///
    /// Returns a pinned, boxed, `Send` stream that may borrow from `self` and the
    /// arguments for `'a`. NOTE(review): the events are presumably expected to
    /// form a sequence that [`accumulate_message`] accepts — confirm.
    fn stream<'a>(
        &'a self,
        model: &'a ModelSpec,
        context: &'a AgentContext,
        options: &'a StreamOptions,
        cancellation_token: CancellationToken,
    ) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>>;
}
322
323pub fn sanitize_incomplete_tool_calls(message: &mut AssistantMessage) -> usize {
346 let mut fixed = 0;
347 for block in &mut message.content {
348 if let ContentBlock::ToolCall {
349 arguments,
350 partial_json,
351 ..
352 } = block
353 {
354 let needs_fix = partial_json.is_some() || !arguments.is_object();
355 if needs_fix {
356 *arguments = Value::Object(serde_json::Map::new());
357 *partial_json = None;
358 fixed += 1;
359 }
360 }
361 }
362 fixed
363}
364
365#[allow(clippy::too_many_lines)]
375pub fn accumulate_message(
376 events: Vec<AssistantMessageEvent>,
377 provider: &str,
378 model_id: &str,
379) -> Result<AssistantMessage, String> {
380 fn ensure_block_open(
381 open_blocks: &[bool],
382 content_index: usize,
383 event_name: &str,
384 ) -> Result<(), String> {
385 match open_blocks.get(content_index) {
386 Some(false) => Err(format!(
387 "{event_name}: block at index {content_index} is already closed"
388 )),
389 Some(true) | None => Ok(()),
390 }
391 }
392
393 fn all_open_blocks_are_tool_calls(content: &[ContentBlock], open_blocks: &[bool]) -> bool {
394 open_blocks
395 .iter()
396 .enumerate()
397 .filter(|(_, open)| **open)
398 .all(|(content_index, _)| {
399 matches!(
400 content.get(content_index),
401 Some(ContentBlock::ToolCall { .. })
402 )
403 })
404 }
405
406 let mut content: Option<Vec<ContentBlock>> = None;
407 let mut open_blocks: Vec<bool> = Vec::new();
411 let mut stop_reason: Option<StopReason> = None;
412 let mut usage: Option<Usage> = None;
413 let mut cost: Option<Cost> = None;
414 let mut error_message: Option<String> = None;
415 let mut error_kind: Option<StreamErrorKind> = None;
416 let mut saw_start = false;
417 let mut saw_terminal = false;
418
419 let tolerate_truncated_tool_args = events.iter().any(|e| {
425 matches!(
426 e,
427 AssistantMessageEvent::Done {
428 stop_reason: StopReason::Length,
429 ..
430 } | AssistantMessageEvent::Error {
431 stop_reason: StopReason::Length,
432 ..
433 }
434 )
435 });
436
437 for event in events {
438 match &event {
440 AssistantMessageEvent::TextStart { .. }
441 | AssistantMessageEvent::TextDelta { .. }
442 | AssistantMessageEvent::TextEnd { .. }
443 | AssistantMessageEvent::ThinkingStart { .. }
444 | AssistantMessageEvent::ThinkingDelta { .. }
445 | AssistantMessageEvent::ThinkingEnd { .. }
446 | AssistantMessageEvent::ToolCallStart { .. }
447 | AssistantMessageEvent::ToolCallDelta { .. }
448 | AssistantMessageEvent::ToolCallEnd { .. } => {
449 if saw_terminal {
450 return Err("content event after terminal event".into());
451 }
452 }
453 AssistantMessageEvent::Done { .. } | AssistantMessageEvent::Error { .. } => {
454 if saw_terminal {
455 return Err("duplicate terminal event".into());
456 }
457 }
458 AssistantMessageEvent::Start => {}
459 }
460
461 match event {
462 AssistantMessageEvent::Start => {
463 if saw_start {
464 return Err("duplicate Start event".into());
465 }
466 saw_start = true;
467 content = Some(Vec::new());
468 }
469
470 AssistantMessageEvent::TextStart { content_index } => {
471 let blocks = content.as_mut().ok_or("TextStart before Start")?;
472 if content_index != blocks.len() {
473 return Err(format!(
474 "TextStart content_index {content_index} != content length {}",
475 blocks.len()
476 ));
477 }
478 blocks.push(ContentBlock::Text {
479 text: String::new(),
480 });
481 open_blocks.push(true);
482 }
483
484 AssistantMessageEvent::TextDelta {
485 content_index,
486 delta,
487 } => {
488 let blocks = content.as_mut().ok_or("TextDelta before Start")?;
489 ensure_block_open(&open_blocks, content_index, "TextDelta")?;
490 let block = blocks
491 .get_mut(content_index)
492 .ok_or_else(|| format!("TextDelta: invalid content_index {content_index}"))?;
493 match block {
494 ContentBlock::Text { text } => text.push_str(&delta),
495 _ => {
496 return Err(format!(
497 "TextDelta: block at index {content_index} is not Text"
498 ));
499 }
500 }
501 }
502
503 AssistantMessageEvent::TextEnd { content_index } => {
504 let blocks = content.as_ref().ok_or("TextEnd before Start")?;
505 let block = blocks
506 .get(content_index)
507 .ok_or_else(|| format!("TextEnd: invalid content_index {content_index}"))?;
508 if !matches!(block, ContentBlock::Text { .. }) {
509 return Err(format!(
510 "TextEnd: block at index {content_index} is not Text"
511 ));
512 }
513 ensure_block_open(&open_blocks, content_index, "TextEnd")?;
514 if let Some(open) = open_blocks.get_mut(content_index) {
515 *open = false;
516 }
517 }
518
519 AssistantMessageEvent::ThinkingStart { content_index } => {
520 let blocks = content.as_mut().ok_or("ThinkingStart before Start")?;
521 if content_index != blocks.len() {
522 return Err(format!(
523 "ThinkingStart content_index {content_index} != content length {}",
524 blocks.len()
525 ));
526 }
527 blocks.push(ContentBlock::Thinking {
528 thinking: String::new(),
529 signature: None,
530 });
531 open_blocks.push(true);
532 }
533
534 AssistantMessageEvent::ThinkingDelta {
535 content_index,
536 delta,
537 } => {
538 let blocks = content.as_mut().ok_or("ThinkingDelta before Start")?;
539 ensure_block_open(&open_blocks, content_index, "ThinkingDelta")?;
540 let block = blocks.get_mut(content_index).ok_or_else(|| {
541 format!("ThinkingDelta: invalid content_index {content_index}")
542 })?;
543 match block {
544 ContentBlock::Thinking { thinking, .. } => thinking.push_str(&delta),
545 _ => {
546 return Err(format!(
547 "ThinkingDelta: block at index {content_index} is not Thinking"
548 ));
549 }
550 }
551 }
552
553 AssistantMessageEvent::ThinkingEnd {
554 content_index,
555 signature,
556 } => {
557 let blocks = content.as_mut().ok_or("ThinkingEnd before Start")?;
558 ensure_block_open(&open_blocks, content_index, "ThinkingEnd")?;
559 let block = blocks
560 .get_mut(content_index)
561 .ok_or_else(|| format!("ThinkingEnd: invalid content_index {content_index}"))?;
562 match block {
563 ContentBlock::Thinking { signature: sig, .. } => *sig = signature,
564 _ => {
565 return Err(format!(
566 "ThinkingEnd: block at index {content_index} is not Thinking"
567 ));
568 }
569 }
570 if let Some(open) = open_blocks.get_mut(content_index) {
571 *open = false;
572 }
573 }
574
575 AssistantMessageEvent::ToolCallStart {
576 content_index,
577 id,
578 name,
579 } => {
580 let blocks = content.as_mut().ok_or("ToolCallStart before Start")?;
581 if content_index != blocks.len() {
582 return Err(format!(
583 "ToolCallStart content_index {content_index} != content length {}",
584 blocks.len()
585 ));
586 }
587 blocks.push(ContentBlock::ToolCall {
588 id,
589 name,
590 arguments: Value::Null,
591 partial_json: Some(String::new()),
592 });
593 open_blocks.push(true);
594 }
595
596 AssistantMessageEvent::ToolCallDelta {
597 content_index,
598 delta,
599 } => {
600 let blocks = content.as_mut().ok_or("ToolCallDelta before Start")?;
601 ensure_block_open(&open_blocks, content_index, "ToolCallDelta")?;
602 let block = blocks.get_mut(content_index).ok_or_else(|| {
603 format!("ToolCallDelta: invalid content_index {content_index}")
604 })?;
605 match block {
606 ContentBlock::ToolCall { partial_json, .. } => {
607 let pj = partial_json
608 .as_mut()
609 .ok_or("ToolCallDelta: partial_json already consumed")?;
610 pj.push_str(&delta);
611 }
612 _ => {
613 return Err(format!(
614 "ToolCallDelta: block at index {content_index} is not ToolCall"
615 ));
616 }
617 }
618 }
619
620 AssistantMessageEvent::ToolCallEnd { content_index } => {
621 let blocks = content.as_mut().ok_or("ToolCallEnd before Start")?;
622 let block = blocks
623 .get_mut(content_index)
624 .ok_or_else(|| format!("ToolCallEnd: invalid content_index {content_index}"))?;
625 ensure_block_open(&open_blocks, content_index, "ToolCallEnd")?;
626 match block {
627 ContentBlock::ToolCall {
628 arguments,
629 partial_json,
630 ..
631 } => {
632 let json_str = partial_json
633 .as_ref()
634 .ok_or("ToolCallEnd: partial_json already consumed")?
635 .clone();
636 if json_str.is_empty() {
637 *arguments = Value::Object(serde_json::Map::new());
638 *partial_json = None;
639 } else {
640 match serde_json::from_str::<Value>(&json_str) {
641 Ok(v) => {
642 *arguments = v;
643 *partial_json = None;
644 }
645 Err(e) => {
646 if tolerate_truncated_tool_args {
647 } else {
651 return Err(format!(
652 "ToolCallEnd: failed to parse arguments JSON: {e}"
653 ));
654 }
655 }
656 }
657 }
658 }
659 _ => {
660 return Err(format!(
661 "ToolCallEnd: block at index {content_index} is not ToolCall"
662 ));
663 }
664 }
665 if let Some(open) = open_blocks.get_mut(content_index) {
666 *open = false;
667 }
668 }
669
670 AssistantMessageEvent::Done {
671 stop_reason: sr,
672 usage: u,
673 cost: c,
674 } => {
675 if let Some(idx) = open_blocks.iter().position(|open| *open) {
676 let content = content.as_ref().ok_or("Done before Start")?;
677 if tolerate_truncated_tool_args
678 && all_open_blocks_are_tool_calls(content, &open_blocks)
679 {
680 tracing::debug!(
685 "Done(Length) with unterminated content block at index {idx} — tolerating for max-tokens recovery"
686 );
687 } else {
688 return Err(format!(
689 "Done received with unterminated content block at index {idx}"
690 ));
691 }
692 }
693 stop_reason = Some(sr);
694 usage = Some(u);
695 cost = Some(c);
696 saw_terminal = true;
697 }
698
699 AssistantMessageEvent::Error {
700 stop_reason: sr,
701 error_message: em,
702 usage: u,
703 error_kind: ek,
704 } => {
705 stop_reason = Some(sr);
706 error_message = Some(em);
707 error_kind = ek;
708 if let Some(u) = u {
709 usage = Some(u);
710 }
711 saw_terminal = true;
712 }
713 }
714 }
715
716 let content = content.ok_or("no Start event found")?;
717 let stop_reason = stop_reason.ok_or("no terminal event (Done or Error) found")?;
718
719 let timestamp = crate::util::now_timestamp();
720
721 Ok(AssistantMessage {
722 content,
723 provider: provider.to_owned(),
724 model_id: model_id.to_owned(),
725 usage: usage.unwrap_or_default(),
726 cost: cost.unwrap_or_default(),
727 stop_reason,
728 error_message,
729 error_kind,
730 timestamp,
731 cache_hint: None,
732 })
733}
734
735const _: () = {
738 const fn assert_send_sync<T: Send + Sync>() {}
739
740 assert_send_sync::<StreamErrorKind>();
741 assert_send_sync::<StreamTransport>();
742 assert_send_sync::<StreamOptions>();
743 assert_send_sync::<AssistantMessageEvent>();
744 assert_send_sync::<AssistantMessageDelta>();
745};
746
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn done_with_unterminated_text_block_is_rejected() {
        // A normal (non-Length) Done must not leave a text block open: the
        // stream claims completion but never sent TextEnd.
        let events = vec![
            AssistantMessageEvent::Start,
            AssistantMessageEvent::TextStart { content_index: 0 },
            AssistantMessageEvent::TextDelta {
                content_index: 0,
                delta: "hi".into(),
            },
            AssistantMessageEvent::Done {
                stop_reason: StopReason::Stop,
                usage: Usage::default(),
                cost: Cost::default(),
            },
        ];
        let err = accumulate_message(events, "test", "test").unwrap_err();
        assert!(err.contains("unterminated content block"), "got: {err}");
    }

    #[test]
    fn done_with_unterminated_tool_call_block_is_rejected() {
        // Open tool calls are only tolerated for StopReason::Length; ToolUse must be rejected.
        let events = vec![
            AssistantMessageEvent::Start,
            AssistantMessageEvent::ToolCallStart {
                content_index: 0,
                id: "t1".into(),
                name: "foo".into(),
            },
            AssistantMessageEvent::ToolCallDelta {
                content_index: 0,
                delta: "{}".into(),
            },
            AssistantMessageEvent::Done {
                stop_reason: StopReason::ToolUse,
                usage: Usage::default(),
                cost: Cost::default(),
            },
        ];
        let err = accumulate_message(events, "test", "test").unwrap_err();
        assert!(err.contains("unterminated content block"), "got: {err}");
    }

    #[test]
    fn done_with_all_blocks_terminated_succeeds() {
        let events = vec![
            AssistantMessageEvent::Start,
            AssistantMessageEvent::TextStart { content_index: 0 },
            AssistantMessageEvent::TextDelta {
                content_index: 0,
                delta: "ok".into(),
            },
            AssistantMessageEvent::TextEnd { content_index: 0 },
            AssistantMessageEvent::Done {
                stop_reason: StopReason::Stop,
                usage: Usage::default(),
                cost: Cost::default(),
            },
        ];
        let msg = accumulate_message(events, "test", "test").expect("should succeed");
        assert_eq!(msg.content.len(), 1);
    }

    #[test]
    fn error_with_unterminated_block_is_allowed() {
        // An Error terminal may arrive mid-block (the stream failed), so open
        // blocks are accepted and the error message is preserved.
        let events = vec![
            AssistantMessageEvent::Start,
            AssistantMessageEvent::TextStart { content_index: 0 },
            AssistantMessageEvent::Error {
                stop_reason: StopReason::Error,
                error_message: "boom".into(),
                usage: None,
                error_kind: None,
            },
        ];
        let msg = accumulate_message(events, "test", "test").expect("error terminal ok");
        assert_eq!(msg.error_message.as_deref(), Some("boom"));
    }

    // The error constructors below each set a specific `error_kind`.

    #[test]
    fn error_constructor_sets_kind_none() {
        let event = AssistantMessageEvent::error("boom");
        match event {
            AssistantMessageEvent::Error { error_kind, .. } => {
                assert_eq!(error_kind, None);
            }
            other => panic!("expected Error, got {other:?}"),
        }
    }

    #[test]
    fn error_throttled_constructor_sets_kind() {
        let event = AssistantMessageEvent::error_throttled("rate limited");
        match event {
            AssistantMessageEvent::Error {
                error_kind,
                error_message,
                ..
            } => {
                assert_eq!(error_kind, Some(StreamErrorKind::Throttled));
                assert_eq!(error_message, "rate limited");
            }
            other => panic!("expected Error, got {other:?}"),
        }
    }

    #[test]
    fn error_context_overflow_constructor_sets_kind() {
        let event = AssistantMessageEvent::error_context_overflow("too long");
        match event {
            AssistantMessageEvent::Error { error_kind, .. } => {
                assert_eq!(error_kind, Some(StreamErrorKind::ContextWindowExceeded));
            }
            other => panic!("expected Error, got {other:?}"),
        }
    }

    #[test]
    fn error_auth_constructor_sets_kind() {
        let event = AssistantMessageEvent::error_auth("bad key");
        match event {
            AssistantMessageEvent::Error { error_kind, .. } => {
                assert_eq!(error_kind, Some(StreamErrorKind::Auth));
            }
            other => panic!("expected Error, got {other:?}"),
        }
    }

    #[test]
    fn error_network_constructor_sets_kind() {
        let event = AssistantMessageEvent::error_network("timeout");
        match event {
            AssistantMessageEvent::Error { error_kind, .. } => {
                assert_eq!(error_kind, Some(StreamErrorKind::Network));
            }
            other => panic!("expected Error, got {other:?}"),
        }
    }

    #[test]
    fn error_content_filtered_constructor_sets_kind() {
        let event = AssistantMessageEvent::error_content_filtered("blocked by safety filter");
        match event {
            AssistantMessageEvent::Error {
                error_kind,
                error_message,
                ..
            } => {
                assert_eq!(error_kind, Some(StreamErrorKind::ContentFiltered));
                assert_eq!(error_message, "blocked by safety filter");
            }
            other => panic!("expected Error, got {other:?}"),
        }
    }

    #[test]
    fn text_response_produces_valid_event_sequence() {
        // Pins the exact five-event shape emitted by `text_response`.
        let events = AssistantMessageEvent::text_response("hello world");
        assert_eq!(events.len(), 5);
        assert!(matches!(events[0], AssistantMessageEvent::Start));
        assert!(matches!(
            events[1],
            AssistantMessageEvent::TextStart { content_index: 0 }
        ));
        match &events[2] {
            AssistantMessageEvent::TextDelta {
                content_index,
                delta,
            } => {
                assert_eq!(*content_index, 0);
                assert_eq!(delta, "hello world");
            }
            other => panic!("expected TextDelta, got {other:?}"),
        }
        assert!(matches!(
            events[3],
            AssistantMessageEvent::TextEnd { content_index: 0 }
        ));
        assert!(matches!(
            events[4],
            AssistantMessageEvent::Done {
                stop_reason: StopReason::Stop,
                ..
            }
        ));
    }

    #[test]
    // Max-tokens truncation: Done(Length) with a still-open tool call is
    // accepted, and the raw (unparsed) JSON is kept in `partial_json` so it can
    // be sanitized later.
    fn done_length_with_unterminated_tool_call_is_tolerated() {
        let events = vec![
            AssistantMessageEvent::Start,
            AssistantMessageEvent::ToolCallStart {
                content_index: 0,
                id: "tc_1".into(),
                name: "read_file".into(),
            },
            AssistantMessageEvent::ToolCallDelta {
                content_index: 0,
                delta: r#"{"path": "/tmp"#.into(),
            },
            AssistantMessageEvent::Done {
                stop_reason: StopReason::Length,
                usage: Usage::default(),
                cost: Cost::default(),
            },
        ];
        let msg = accumulate_message(events, "test", "test")
            .expect("Done(Length) with open tool-call block should succeed");
        assert_eq!(msg.stop_reason, StopReason::Length);
        match &msg.content[0] {
            // The truncated arguments were never parsed, so the buffer survives.
            ContentBlock::ToolCall { partial_json, .. } => {
                assert!(
                    partial_json.is_some(),
                    "partial_json should be Some for incomplete tool call"
                );
            }
            other => panic!("expected ToolCall, got {other:?}"),
        }
    }

    #[test]
    fn done_length_with_unterminated_text_block_is_rejected() {
        // The Length tolerance is scoped to tool calls only.
        let events = vec![
            AssistantMessageEvent::Start,
            AssistantMessageEvent::TextStart { content_index: 0 },
            AssistantMessageEvent::TextDelta {
                content_index: 0,
                delta: "partial".into(),
            },
            AssistantMessageEvent::Done {
                stop_reason: StopReason::Length,
                usage: Usage::default(),
                cost: Cost::default(),
            },
        ];

        let err = accumulate_message(events, "test", "test").unwrap_err();
        assert!(err.contains("unterminated content block"), "got: {err}");
    }

    #[test]
    fn done_length_with_unterminated_thinking_block_is_rejected() {
        // Same as above for thinking blocks: only tool calls get the Length tolerance.
        let events = vec![
            AssistantMessageEvent::Start,
            AssistantMessageEvent::ThinkingStart { content_index: 0 },
            AssistantMessageEvent::ThinkingDelta {
                content_index: 0,
                delta: "partial".into(),
            },
            AssistantMessageEvent::Done {
                stop_reason: StopReason::Length,
                usage: Usage::default(),
                cost: Cost::default(),
            },
        ];

        let err = accumulate_message(events, "test", "test").unwrap_err();
        assert!(err.contains("unterminated content block"), "got: {err}");
    }

    #[test]
    fn text_response_accumulates_correctly() {
        // Round-trip: the canned sequence is accepted by the accumulator.
        let events = AssistantMessageEvent::text_response("accumulated text");
        let msg = accumulate_message(events, "test", "test-model").expect("accumulation failed");
        assert_eq!(msg.content.len(), 1);
        assert_eq!(ContentBlock::extract_text(&msg.content), "accumulated text");
        assert_eq!(msg.stop_reason, StopReason::Stop);
    }

    // The following tests pin the exact error strings for events that target
    // a block which has already been closed.

    #[test]
    fn text_delta_after_text_end_is_rejected() {
        let events = vec![
            AssistantMessageEvent::Start,
            AssistantMessageEvent::TextStart { content_index: 0 },
            AssistantMessageEvent::TextDelta {
                content_index: 0,
                delta: "hello".into(),
            },
            AssistantMessageEvent::TextEnd { content_index: 0 },
            AssistantMessageEvent::TextDelta {
                content_index: 0,
                delta: " again".into(),
            },
            AssistantMessageEvent::Done {
                stop_reason: StopReason::Stop,
                usage: Usage::default(),
                cost: Cost::default(),
            },
        ];

        let err = accumulate_message(events, "test", "test").unwrap_err();
        assert_eq!(err, "TextDelta: block at index 0 is already closed");
    }

    #[test]
    fn duplicate_text_end_is_rejected() {
        let events = vec![
            AssistantMessageEvent::Start,
            AssistantMessageEvent::TextStart { content_index: 0 },
            AssistantMessageEvent::TextDelta {
                content_index: 0,
                delta: "hello".into(),
            },
            AssistantMessageEvent::TextEnd { content_index: 0 },
            AssistantMessageEvent::TextEnd { content_index: 0 },
            AssistantMessageEvent::Done {
                stop_reason: StopReason::Stop,
                usage: Usage::default(),
                cost: Cost::default(),
            },
        ];

        let err = accumulate_message(events, "test", "test").unwrap_err();
        assert_eq!(err, "TextEnd: block at index 0 is already closed");
    }

    #[test]
    fn duplicate_thinking_end_is_rejected() {
        let events = vec![
            AssistantMessageEvent::Start,
            AssistantMessageEvent::ThinkingStart { content_index: 0 },
            AssistantMessageEvent::ThinkingDelta {
                content_index: 0,
                delta: "step 1".into(),
            },
            AssistantMessageEvent::ThinkingEnd {
                content_index: 0,
                signature: Some("sig-1".into()),
            },
            AssistantMessageEvent::ThinkingEnd {
                content_index: 0,
                signature: Some("sig-2".into()),
            },
            AssistantMessageEvent::Done {
                stop_reason: StopReason::Stop,
                usage: Usage::default(),
                cost: Cost::default(),
            },
        ];

        let err = accumulate_message(events, "test", "test").unwrap_err();
        assert_eq!(err, "ThinkingEnd: block at index 0 is already closed");
    }

    #[test]
    fn tool_call_delta_after_end_is_rejected() {
        let events = vec![
            AssistantMessageEvent::Start,
            AssistantMessageEvent::ToolCallStart {
                content_index: 0,
                id: "tool-1".into(),
                name: "read_file".into(),
            },
            AssistantMessageEvent::ToolCallDelta {
                content_index: 0,
                delta: "{\"path\":\"/tmp/a\"}".into(),
            },
            AssistantMessageEvent::ToolCallEnd { content_index: 0 },
            AssistantMessageEvent::ToolCallDelta {
                content_index: 0,
                delta: ",\"extra\":true}".into(),
            },
            AssistantMessageEvent::Done {
                stop_reason: StopReason::ToolUse,
                usage: Usage::default(),
                cost: Cost::default(),
            },
        ];

        let err = accumulate_message(events, "test", "test").unwrap_err();
        assert_eq!(err, "ToolCallDelta: block at index 0 is already closed");
    }

    #[test]
    fn duplicate_tool_call_end_is_rejected() {
        let events = vec![
            AssistantMessageEvent::Start,
            AssistantMessageEvent::ToolCallStart {
                content_index: 0,
                id: "tool-1".into(),
                name: "read_file".into(),
            },
            AssistantMessageEvent::ToolCallDelta {
                content_index: 0,
                delta: "{\"path\":\"/tmp/a\"}".into(),
            },
            AssistantMessageEvent::ToolCallEnd { content_index: 0 },
            AssistantMessageEvent::ToolCallEnd { content_index: 0 },
            AssistantMessageEvent::Done {
                stop_reason: StopReason::ToolUse,
                usage: Usage::default(),
                cost: Cost::default(),
            },
        ];

        let err = accumulate_message(events, "test", "test").unwrap_err();
        assert_eq!(err, "ToolCallEnd: block at index 0 is already closed");
    }

    /// Builds a single-block assistant message (stop reason `Length`) holding one
    /// tool call with the given `arguments` and `partial_json`.
    fn build_assistant_with_tool_call(
        arguments: Value,
        partial_json: Option<String>,
    ) -> AssistantMessage {
        AssistantMessage {
            content: vec![ContentBlock::ToolCall {
                id: "tc_1".into(),
                name: "read_file".into(),
                arguments,
                partial_json,
            }],
            provider: "test".into(),
            model_id: "test".into(),
            usage: Usage::default(),
            cost: Cost::default(),
            stop_reason: StopReason::Length,
            error_message: None,
            error_kind: None,
            timestamp: 0,
            cache_hint: None,
        }
    }

    #[test]
    fn sanitize_null_arguments_with_partial_json_returns_empty_object() {
        let mut msg = build_assistant_with_tool_call(Value::Null, Some("{\"path\": \"/tm".into()));
        let fixed = sanitize_incomplete_tool_calls(&mut msg);
        assert_eq!(fixed, 1);
        match &msg.content[0] {
            ContentBlock::ToolCall {
                arguments,
                partial_json,
                ..
            } => {
                assert_eq!(*arguments, Value::Object(serde_json::Map::new()));
                assert!(
                    partial_json.is_none(),
                    "partial_json must be cleared after scrub"
                );
            }
            other => panic!("expected ToolCall, got {other:?}"),
        }
    }

    #[test]
    fn sanitize_leaves_valid_object_arguments_untouched() {
        let args = serde_json::json!({ "path": "/tmp/a" });
        let mut msg = build_assistant_with_tool_call(args.clone(), None);
        let fixed = sanitize_incomplete_tool_calls(&mut msg);
        assert_eq!(fixed, 0);
        match &msg.content[0] {
            ContentBlock::ToolCall {
                arguments,
                partial_json,
                ..
            } => {
                assert_eq!(*arguments, args);
                assert!(partial_json.is_none());
            }
            other => panic!("expected ToolCall, got {other:?}"),
        }
    }

    #[test]
    fn sanitize_coerces_non_object_arguments() {
        let mut msg = build_assistant_with_tool_call(Value::String("truncated".into()), None);
        // No partial_json, but the arguments are not an object, so the block
        // still counts as incomplete and is rewritten.
        let fixed = sanitize_incomplete_tool_calls(&mut msg);
        assert_eq!(fixed, 1);
        match &msg.content[0] {
            ContentBlock::ToolCall { arguments, .. } => {
                assert_eq!(*arguments, Value::Object(serde_json::Map::new()));
            }
            other => panic!("expected ToolCall, got {other:?}"),
        }
    }

    #[test]
    fn sanitize_is_idempotent() {
        let mut msg = build_assistant_with_tool_call(Value::Null, Some("{\"path\":".into()));
        assert_eq!(sanitize_incomplete_tool_calls(&mut msg), 1);
        // The second pass finds nothing left to fix.
        assert_eq!(sanitize_incomplete_tool_calls(&mut msg), 0);
    }

    #[test]
    fn sanitize_preserves_non_tool_blocks() {
        let mut msg = AssistantMessage {
            content: vec![
                ContentBlock::Text {
                    text: "hello".into(),
                },
                ContentBlock::ToolCall {
                    id: "tc_1".into(),
                    name: "foo".into(),
                    arguments: Value::Null,
                    partial_json: Some("{".into()),
                },
                ContentBlock::Text {
                    text: "world".into(),
                },
            ],
            provider: "test".into(),
            model_id: "test".into(),
            usage: Usage::default(),
            cost: Cost::default(),
            stop_reason: StopReason::Length,
            error_message: None,
            error_kind: None,
            timestamp: 0,
            cache_hint: None,
        };
        let fixed = sanitize_incomplete_tool_calls(&mut msg);
        assert_eq!(fixed, 1);
        match &msg.content[0] {
            // Text blocks on either side of the tool call must be untouched.
            ContentBlock::Text { text } => assert_eq!(text, "hello"),
            other => panic!("expected Text, got {other:?}"),
        }
        match &msg.content[2] {
            ContentBlock::Text { text } => assert_eq!(text, "world"),
            other => panic!("expected Text, got {other:?}"),
        }
    }

    #[test]
    // End-to-end: a truncated tool call survives accumulation with
    // Null arguments and raw partial JSON, and sanitizing turns it into an
    // empty-object call with no leftover partial JSON.
    fn accumulate_plus_sanitize_yields_adapter_safe_tool_call() {
        let events = vec![
            AssistantMessageEvent::Start,
            AssistantMessageEvent::ToolCallStart {
                content_index: 0,
                id: "tc_1".into(),
                name: "read_file".into(),
            },
            AssistantMessageEvent::ToolCallDelta {
                content_index: 0,
                delta: r#"{"path": "/tm"#.into(),
            },
            AssistantMessageEvent::Done {
                stop_reason: StopReason::Length,
                usage: Usage::default(),
                cost: Cost::default(),
            },
        ];
        let mut msg = accumulate_message(events, "test", "test")
            .expect("Done(Length) with unterminated tool-call should accumulate");
        match &msg.content[0] {
            // Before sanitizing: raw JSON kept, arguments never parsed.
            ContentBlock::ToolCall {
                arguments,
                partial_json,
                ..
            } => {
                assert!(partial_json.is_some());
                assert!(arguments.is_null());
            }
            other => panic!("expected ToolCall, got {other:?}"),
        }

        sanitize_incomplete_tool_calls(&mut msg);

        match &msg.content[0] {
            // After sanitizing: empty-object arguments, partial JSON cleared.
            ContentBlock::ToolCall {
                arguments,
                partial_json,
                ..
            } => {
                assert!(arguments.is_object());
                assert_eq!(arguments.as_object().unwrap().len(), 0);
                assert!(partial_json.is_none());
            }
            other => panic!("expected ToolCall, got {other:?}"),
        }
    }
}
1339}