Skip to main content

swink_agent/
stream.rs

1//! Streaming interface traits and types.
2//!
3//! Defines the `StreamFn` trait (the pluggable boundary between the harness and
4//! LLM providers), the event protocol for incremental message delivery, and a
5//! delta-accumulation function that reconstructs a finalized `AssistantMessage`
6//! from a collected sequence of events.
7
8use futures::Stream;
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::borrow::Cow;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio_util::sync::CancellationToken;
16
17pub use crate::stream_error_kind::StreamErrorKind;
18use crate::types::{
19    AgentContext, AssistantMessage, ContentBlock, Cost, ModelSpec, StopReason, Usage,
20};
21
22// ─── StreamTransport ─────────────────────────────────────────────────────────
23
24/// Transport protocol for streaming responses.
25#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
26#[serde(rename_all = "snake_case")]
27pub enum StreamTransport {
28    /// Server-Sent Events (default).
29    #[default]
30    Sse,
31}
32
33// ─── CacheStrategy ──────────────────────────────────────────────────────────
34
/// Caching configuration expressed independently of any single provider.
///
/// Each adapter is responsible for turning the chosen strategy into its own
/// provider-specific cache markers while building the request. An adapter
/// without caching support simply ignores the strategy.
#[derive(Debug, Clone, Default)]
pub enum CacheStrategy {
    /// Caching disabled: no cache markers are injected. This is the default.
    #[default]
    None,
    /// Let the adapter pick the best cache points itself (for example the
    /// system prompt and tool definitions for Anthropic, or long context for
    /// Google).
    Auto,
    /// Anthropic-specific: attach `cache_control: { type: "ephemeral" }`
    /// blocks to the system prompt and the tool definitions.
    Anthropic,
    /// Google-specific: point at a `CachedContent` resource that lives for the
    /// given TTL.
    Google {
        /// How long the cached content should be kept alive.
        ttl: Duration,
    },
}
57
58// ─── OnRawPayload ───────────────────────────────────────────────────────────
59
/// Observer invoked with raw SSE data lines before they are parsed into events.
///
/// The callback runs synchronously for every raw `data:` line, so it has to
/// return quickly (fire-and-forget semantics). A panic inside the callback is
/// caught and does not interrupt the streaming pipeline.
pub type OnRawPayload = Arc<dyn Fn(&str) + Send + Sync>;
66
67// ─── StreamOptions ───────────────────────────────────────────────────────────
68
69/// Per-call configuration passed through to the LLM provider.
70#[derive(Clone, Default)]
71pub struct StreamOptions {
72    /// Sampling temperature (optional).
73    pub temperature: Option<f64>,
74    /// Output token limit (optional).
75    pub max_tokens: Option<u64>,
76    /// Provider-side session identifier for caching (optional).
77    pub session_id: Option<String>,
78    /// Dynamically resolved API key for this specific request (optional).
79    pub api_key: Option<String>,
80    /// Preferred transport protocol.
81    pub transport: StreamTransport,
82    /// Provider-agnostic caching configuration.
83    pub cache_strategy: CacheStrategy,
84    /// Optional callback for observing raw SSE data lines before parsing.
85    pub on_raw_payload: Option<OnRawPayload>,
86}
87
88impl std::fmt::Debug for StreamOptions {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        f.debug_struct("StreamOptions")
91            .field("temperature", &self.temperature)
92            .field("max_tokens", &self.max_tokens)
93            .field("session_id", &self.session_id)
94            .field("api_key", &self.api_key.as_ref().map(|_| "[REDACTED]"))
95            .field("transport", &self.transport)
96            .field("cache_strategy", &self.cache_strategy)
97            .field(
98                "on_raw_payload",
99                &self.on_raw_payload.as_ref().map(|_| "<callback>"),
100            )
101            .finish()
102    }
103}
104
105// ─── AssistantMessageEvent ───────────────────────────────────────────────────
106
107/// An incremental event emitted by a `StreamFn` implementation.
108///
109/// Events follow a strict start/delta/end protocol per content block. Each
110/// block carries a `content_index` that identifies its position in the final
111/// message's content vec.
112#[non_exhaustive]
113#[derive(Debug, Clone)]
114pub enum AssistantMessageEvent {
115    /// The stream has opened.
116    Start,
117
118    /// A new text content block is starting at `content_index`.
119    TextStart { content_index: usize },
120    /// An incremental text fragment for the block at `content_index`.
121    TextDelta { content_index: usize, delta: String },
122    /// The text block at `content_index` is complete.
123    TextEnd { content_index: usize },
124
125    /// A new thinking content block is starting at `content_index`.
126    ThinkingStart { content_index: usize },
127    /// An incremental thinking fragment for the block at `content_index`.
128    ThinkingDelta { content_index: usize, delta: String },
129    /// The thinking block at `content_index` is complete, with an optional
130    /// provider verification signature.
131    ThinkingEnd {
132        content_index: usize,
133        signature: Option<String>,
134    },
135
136    /// A new tool call content block is starting at `content_index`.
137    ToolCallStart {
138        content_index: usize,
139        id: String,
140        name: String,
141    },
142    /// An incremental JSON argument fragment for the tool call at `content_index`.
143    ToolCallDelta { content_index: usize, delta: String },
144    /// The tool call at `content_index` is complete.
145    ToolCallEnd { content_index: usize },
146
147    /// The stream completed successfully.
148    Done {
149        stop_reason: StopReason,
150        usage: Usage,
151        cost: Cost,
152    },
153
154    /// The stream ended with an error.
155    Error {
156        stop_reason: StopReason,
157        error_message: String,
158        usage: Option<Usage>,
159        /// Optional structured error classification.
160        ///
161        /// When set, the agent loop uses this to classify the error without
162        /// falling back to string matching on `error_message`.
163        error_kind: Option<StreamErrorKind>,
164    },
165}
166
167impl AssistantMessageEvent {
168    /// Create a stream error event with no structured classification.
169    ///
170    /// Convenience constructor used by adapters when the stream encounters
171    /// an error condition. The `error_kind` is set to `None`, so the agent
172    /// loop will fall back to string-based classification.
173    pub fn error(message: impl Into<String>) -> Self {
174        Self::Error {
175            stop_reason: StopReason::Error,
176            error_message: message.into(),
177            usage: None,
178            error_kind: None,
179        }
180    }
181
182    /// Create a throttle/rate-limit error event.
183    ///
184    /// Sets [`StreamErrorKind::Throttled`] so the agent loop can classify
185    /// the error structurally.
186    pub fn error_throttled(message: impl Into<String>) -> Self {
187        Self::Error {
188            stop_reason: StopReason::Error,
189            error_message: message.into(),
190            usage: None,
191            error_kind: Some(StreamErrorKind::Throttled),
192        }
193    }
194
195    /// Create a context-window overflow error event.
196    ///
197    /// Sets [`StreamErrorKind::ContextWindowExceeded`] so the agent loop
198    /// can trigger context compaction.
199    pub fn error_context_overflow(message: impl Into<String>) -> Self {
200        Self::Error {
201            stop_reason: StopReason::Error,
202            error_message: message.into(),
203            usage: None,
204            error_kind: Some(StreamErrorKind::ContextWindowExceeded),
205        }
206    }
207
208    /// Create an authentication error event.
209    ///
210    /// Sets [`StreamErrorKind::Auth`] so the agent loop can treat this as
211    /// a non-retryable failure.
212    pub fn error_auth(message: impl Into<String>) -> Self {
213        Self::Error {
214            stop_reason: StopReason::Error,
215            error_message: message.into(),
216            usage: None,
217            error_kind: Some(StreamErrorKind::Auth),
218        }
219    }
220
221    /// Create a network/server error event.
222    ///
223    /// Sets [`StreamErrorKind::Network`] so the agent loop can classify
224    /// the error as retryable.
225    pub fn error_network(message: impl Into<String>) -> Self {
226        Self::Error {
227            stop_reason: StopReason::Error,
228            error_message: message.into(),
229            usage: None,
230            error_kind: Some(StreamErrorKind::Network),
231        }
232    }
233
234    /// Create a content-filtered error event.
235    ///
236    /// Sets [`StreamErrorKind::ContentFiltered`] so the agent loop can
237    /// treat this as a non-retryable safety policy violation.
238    pub fn error_content_filtered(message: impl Into<String>) -> Self {
239        Self::Error {
240            stop_reason: StopReason::Error,
241            error_message: message.into(),
242            usage: None,
243            error_kind: Some(StreamErrorKind::ContentFiltered),
244        }
245    }
246
247    /// Build a complete single-text-block response event sequence.
248    ///
249    /// Useful for testing and mock `StreamFn` implementations. Returns the
250    /// five events needed for a valid text-only response: `Start`, `TextStart`,
251    /// `TextDelta`, `TextEnd`, and `Done`.
252    pub fn text_response(text: &str) -> Vec<Self> {
253        vec![
254            Self::Start,
255            Self::TextStart { content_index: 0 },
256            Self::TextDelta {
257                content_index: 0,
258                delta: text.to_string(),
259            },
260            Self::TextEnd { content_index: 0 },
261            Self::Done {
262                stop_reason: StopReason::Stop,
263                usage: Usage::default(),
264                cost: Cost::default(),
265            },
266        ]
267    }
268}
269
270// ─── AssistantMessageDelta ───────────────────────────────────────────────────
271
272/// A typed incremental update during streaming, used in `MessageUpdate` events.
273///
274/// The `delta` field uses [`Cow<'static, str>`] to avoid cloning on the hot
275/// path when the caller can transfer ownership of the underlying `String`.
276#[derive(Debug, Clone, Serialize, Deserialize)]
277#[serde(tag = "type", rename_all = "snake_case")]
278pub enum AssistantMessageDelta {
279    /// An appended text string fragment.
280    Text {
281        content_index: usize,
282        delta: Cow<'static, str>,
283    },
284    /// An appended reasoning fragment.
285    Thinking {
286        content_index: usize,
287        delta: Cow<'static, str>,
288    },
289    /// An appended JSON argument fragment for a tool call.
290    ToolCall {
291        content_index: usize,
292        delta: Cow<'static, str>,
293    },
294}
295
296// ─── StreamFn Trait ──────────────────────────────────────────────────────────
297
298/// The pluggable boundary between the harness and LLM providers.
299///
300/// Callers supply an implementation that accepts a model specification, an
301/// agent context, and stream options, and returns an async stream of
302/// `AssistantMessageEvent` values. The harness consumes this stream to build
303/// up the assistant message incrementally.
304///
305/// This trait is object-safe and requires `Send + Sync` so that it can be
306/// stored behind an `Arc` and shared across async tasks.
307pub trait StreamFn: Send + Sync {
308    /// Initiate a streaming LLM call.
309    ///
310    /// The returned stream yields `AssistantMessageEvent` values following the
311    /// start/delta/end protocol. Implementations must respect the provided
312    /// `cancellation_token` — when the token is cancelled, the stream should
313    /// terminate promptly.
314    fn stream<'a>(
315        &'a self,
316        model: &'a ModelSpec,
317        context: &'a AgentContext,
318        options: &'a StreamOptions,
319        cancellation_token: CancellationToken,
320    ) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>>;
321}
322
323// ─── Tool-call sanitization ──────────────────────────────────────────────────
324
325/// Scrub incomplete `ToolCall` blocks in an assistant message so it can safely
326/// be replayed to a provider.
327///
328/// When a stream hits [`StopReason::Length`] mid tool-use, the resulting
329/// [`ContentBlock::ToolCall`] may carry `arguments: Value::Null` with
330/// `partial_json: Some(..)` (see `accumulate_message`). Provider adapters
331/// forward `arguments` verbatim; Anthropic/Google/Bedrock reject null inputs
332/// and OpenAI-compatible providers reject the literal string `"null"`. On the
333/// next turn this causes a 400.
334///
335/// This helper coerces any `ToolCall` block whose `partial_json` is still set
336/// OR whose `arguments` is not a JSON object into a valid empty-object call:
337/// `arguments = Value::Object({})` and `partial_json = None`. The loop pairs
338/// this with a synthetic tool-result message (see
339/// `recover_incomplete_tool_calls`) so the provider sees a well-formed pair.
340///
341/// Safe to call multiple times and safe to call on messages that contain no
342/// tool-use blocks. Returns the number of blocks modified.
343///
344/// See <https://github.com/SuperSwinkAI/Swink-Agent/issues/619>.
345pub fn sanitize_incomplete_tool_calls(message: &mut AssistantMessage) -> usize {
346    let mut fixed = 0;
347    for block in &mut message.content {
348        if let ContentBlock::ToolCall {
349            arguments,
350            partial_json,
351            ..
352        } = block
353        {
354            let needs_fix = partial_json.is_some() || !arguments.is_object();
355            if needs_fix {
356                *arguments = Value::Object(serde_json::Map::new());
357                *partial_json = None;
358                fixed += 1;
359            }
360        }
361    }
362    fixed
363}
364
365// ─── Delta Accumulation ──────────────────────────────────────────────────────
366
367/// Reconstruct a finalized `AssistantMessage` from a collected list of stream
368/// events.
369///
370/// # Errors
371///
372/// Returns a descriptive error string if the event sequence is malformed (e.g.
373/// delta for a non-existent content index, missing `Start` or terminal event).
374#[allow(clippy::too_many_lines)]
375pub fn accumulate_message(
376    events: Vec<AssistantMessageEvent>,
377    provider: &str,
378    model_id: &str,
379) -> Result<AssistantMessage, String> {
380    fn ensure_block_open(
381        open_blocks: &[bool],
382        content_index: usize,
383        event_name: &str,
384    ) -> Result<(), String> {
385        match open_blocks.get(content_index) {
386            Some(false) => Err(format!(
387                "{event_name}: block at index {content_index} is already closed"
388            )),
389            Some(true) | None => Ok(()),
390        }
391    }
392
393    fn all_open_blocks_are_tool_calls(content: &[ContentBlock], open_blocks: &[bool]) -> bool {
394        open_blocks
395            .iter()
396            .enumerate()
397            .filter(|(_, open)| **open)
398            .all(|(content_index, _)| {
399                matches!(
400                    content.get(content_index),
401                    Some(ContentBlock::ToolCall { .. })
402                )
403            })
404    }
405
406    let mut content: Option<Vec<ContentBlock>> = None;
407    // Parallel to `content`: tracks whether each block is still open (awaiting
408    // its matching `*End` event). Finalization (on `Done`) fails if any block
409    // is still open, preventing silently-corrupt assistant messages.
410    let mut open_blocks: Vec<bool> = Vec::new();
411    let mut stop_reason: Option<StopReason> = None;
412    let mut usage: Option<Usage> = None;
413    let mut cost: Option<Cost> = None;
414    let mut error_message: Option<String> = None;
415    let mut error_kind: Option<StreamErrorKind> = None;
416    let mut saw_start = false;
417    let mut saw_terminal = false;
418
419    // Pre-scan for a Length stop reason. Providers can emit `ToolCallEnd` with
420    // truncated JSON arguments (or omit it entirely) when hitting the max-token
421    // limit mid tool-call. We must preserve the incomplete block so the loop's
422    // `recover_incomplete_tool_calls` path can convert it into an error tool
423    // result and continue on the next turn. See issue #221.
424    let tolerate_truncated_tool_args = events.iter().any(|e| {
425        matches!(
426            e,
427            AssistantMessageEvent::Done {
428                stop_reason: StopReason::Length,
429                ..
430            } | AssistantMessageEvent::Error {
431                stop_reason: StopReason::Length,
432                ..
433            }
434        )
435    });
436
437    for event in events {
438        // Reject content-block events after a terminal event.
439        match &event {
440            AssistantMessageEvent::TextStart { .. }
441            | AssistantMessageEvent::TextDelta { .. }
442            | AssistantMessageEvent::TextEnd { .. }
443            | AssistantMessageEvent::ThinkingStart { .. }
444            | AssistantMessageEvent::ThinkingDelta { .. }
445            | AssistantMessageEvent::ThinkingEnd { .. }
446            | AssistantMessageEvent::ToolCallStart { .. }
447            | AssistantMessageEvent::ToolCallDelta { .. }
448            | AssistantMessageEvent::ToolCallEnd { .. } => {
449                if saw_terminal {
450                    return Err("content event after terminal event".into());
451                }
452            }
453            AssistantMessageEvent::Done { .. } | AssistantMessageEvent::Error { .. } => {
454                if saw_terminal {
455                    return Err("duplicate terminal event".into());
456                }
457            }
458            AssistantMessageEvent::Start => {}
459        }
460
461        match event {
462            AssistantMessageEvent::Start => {
463                if saw_start {
464                    return Err("duplicate Start event".into());
465                }
466                saw_start = true;
467                content = Some(Vec::new());
468            }
469
470            AssistantMessageEvent::TextStart { content_index } => {
471                let blocks = content.as_mut().ok_or("TextStart before Start")?;
472                if content_index != blocks.len() {
473                    return Err(format!(
474                        "TextStart content_index {content_index} != content length {}",
475                        blocks.len()
476                    ));
477                }
478                blocks.push(ContentBlock::Text {
479                    text: String::new(),
480                });
481                open_blocks.push(true);
482            }
483
484            AssistantMessageEvent::TextDelta {
485                content_index,
486                delta,
487            } => {
488                let blocks = content.as_mut().ok_or("TextDelta before Start")?;
489                ensure_block_open(&open_blocks, content_index, "TextDelta")?;
490                let block = blocks
491                    .get_mut(content_index)
492                    .ok_or_else(|| format!("TextDelta: invalid content_index {content_index}"))?;
493                match block {
494                    ContentBlock::Text { text } => text.push_str(&delta),
495                    _ => {
496                        return Err(format!(
497                            "TextDelta: block at index {content_index} is not Text"
498                        ));
499                    }
500                }
501            }
502
503            AssistantMessageEvent::TextEnd { content_index } => {
504                let blocks = content.as_ref().ok_or("TextEnd before Start")?;
505                let block = blocks
506                    .get(content_index)
507                    .ok_or_else(|| format!("TextEnd: invalid content_index {content_index}"))?;
508                if !matches!(block, ContentBlock::Text { .. }) {
509                    return Err(format!(
510                        "TextEnd: block at index {content_index} is not Text"
511                    ));
512                }
513                ensure_block_open(&open_blocks, content_index, "TextEnd")?;
514                if let Some(open) = open_blocks.get_mut(content_index) {
515                    *open = false;
516                }
517            }
518
519            AssistantMessageEvent::ThinkingStart { content_index } => {
520                let blocks = content.as_mut().ok_or("ThinkingStart before Start")?;
521                if content_index != blocks.len() {
522                    return Err(format!(
523                        "ThinkingStart content_index {content_index} != content length {}",
524                        blocks.len()
525                    ));
526                }
527                blocks.push(ContentBlock::Thinking {
528                    thinking: String::new(),
529                    signature: None,
530                });
531                open_blocks.push(true);
532            }
533
534            AssistantMessageEvent::ThinkingDelta {
535                content_index,
536                delta,
537            } => {
538                let blocks = content.as_mut().ok_or("ThinkingDelta before Start")?;
539                ensure_block_open(&open_blocks, content_index, "ThinkingDelta")?;
540                let block = blocks.get_mut(content_index).ok_or_else(|| {
541                    format!("ThinkingDelta: invalid content_index {content_index}")
542                })?;
543                match block {
544                    ContentBlock::Thinking { thinking, .. } => thinking.push_str(&delta),
545                    _ => {
546                        return Err(format!(
547                            "ThinkingDelta: block at index {content_index} is not Thinking"
548                        ));
549                    }
550                }
551            }
552
553            AssistantMessageEvent::ThinkingEnd {
554                content_index,
555                signature,
556            } => {
557                let blocks = content.as_mut().ok_or("ThinkingEnd before Start")?;
558                ensure_block_open(&open_blocks, content_index, "ThinkingEnd")?;
559                let block = blocks
560                    .get_mut(content_index)
561                    .ok_or_else(|| format!("ThinkingEnd: invalid content_index {content_index}"))?;
562                match block {
563                    ContentBlock::Thinking { signature: sig, .. } => *sig = signature,
564                    _ => {
565                        return Err(format!(
566                            "ThinkingEnd: block at index {content_index} is not Thinking"
567                        ));
568                    }
569                }
570                if let Some(open) = open_blocks.get_mut(content_index) {
571                    *open = false;
572                }
573            }
574
575            AssistantMessageEvent::ToolCallStart {
576                content_index,
577                id,
578                name,
579            } => {
580                let blocks = content.as_mut().ok_or("ToolCallStart before Start")?;
581                if content_index != blocks.len() {
582                    return Err(format!(
583                        "ToolCallStart content_index {content_index} != content length {}",
584                        blocks.len()
585                    ));
586                }
587                blocks.push(ContentBlock::ToolCall {
588                    id,
589                    name,
590                    arguments: Value::Null,
591                    partial_json: Some(String::new()),
592                });
593                open_blocks.push(true);
594            }
595
596            AssistantMessageEvent::ToolCallDelta {
597                content_index,
598                delta,
599            } => {
600                let blocks = content.as_mut().ok_or("ToolCallDelta before Start")?;
601                ensure_block_open(&open_blocks, content_index, "ToolCallDelta")?;
602                let block = blocks.get_mut(content_index).ok_or_else(|| {
603                    format!("ToolCallDelta: invalid content_index {content_index}")
604                })?;
605                match block {
606                    ContentBlock::ToolCall { partial_json, .. } => {
607                        let pj = partial_json
608                            .as_mut()
609                            .ok_or("ToolCallDelta: partial_json already consumed")?;
610                        pj.push_str(&delta);
611                    }
612                    _ => {
613                        return Err(format!(
614                            "ToolCallDelta: block at index {content_index} is not ToolCall"
615                        ));
616                    }
617                }
618            }
619
620            AssistantMessageEvent::ToolCallEnd { content_index } => {
621                let blocks = content.as_mut().ok_or("ToolCallEnd before Start")?;
622                let block = blocks
623                    .get_mut(content_index)
624                    .ok_or_else(|| format!("ToolCallEnd: invalid content_index {content_index}"))?;
625                ensure_block_open(&open_blocks, content_index, "ToolCallEnd")?;
626                match block {
627                    ContentBlock::ToolCall {
628                        arguments,
629                        partial_json,
630                        ..
631                    } => {
632                        let json_str = partial_json
633                            .as_ref()
634                            .ok_or("ToolCallEnd: partial_json already consumed")?
635                            .clone();
636                        if json_str.is_empty() {
637                            *arguments = Value::Object(serde_json::Map::new());
638                            *partial_json = None;
639                        } else {
640                            match serde_json::from_str::<Value>(&json_str) {
641                                Ok(v) => {
642                                    *arguments = v;
643                                    *partial_json = None;
644                                }
645                                Err(e) => {
646                                    if tolerate_truncated_tool_args {
647                                        // Leave `partial_json` as Some so the
648                                        // block is flagged incomplete and the
649                                        // loop recovers on the next turn.
650                                    } else {
651                                        return Err(format!(
652                                            "ToolCallEnd: failed to parse arguments JSON: {e}"
653                                        ));
654                                    }
655                                }
656                            }
657                        }
658                    }
659                    _ => {
660                        return Err(format!(
661                            "ToolCallEnd: block at index {content_index} is not ToolCall"
662                        ));
663                    }
664                }
665                if let Some(open) = open_blocks.get_mut(content_index) {
666                    *open = false;
667                }
668            }
669
670            AssistantMessageEvent::Done {
671                stop_reason: sr,
672                usage: u,
673                cost: c,
674            } => {
675                if let Some(idx) = open_blocks.iter().position(|open| *open) {
676                    let content = content.as_ref().ok_or("Done before Start")?;
677                    if tolerate_truncated_tool_args
678                        && all_open_blocks_are_tool_calls(content, &open_blocks)
679                    {
680                        // Max-tokens truncation: leave open tool-call blocks
681                        // with `partial_json` set so the loop's
682                        // `recover_incomplete_tool_calls` path can convert
683                        // them into error tool results on the next turn.
684                        tracing::debug!(
685                            "Done(Length) with unterminated content block at index {idx} — tolerating for max-tokens recovery"
686                        );
687                    } else {
688                        return Err(format!(
689                            "Done received with unterminated content block at index {idx}"
690                        ));
691                    }
692                }
693                stop_reason = Some(sr);
694                usage = Some(u);
695                cost = Some(c);
696                saw_terminal = true;
697            }
698
699            AssistantMessageEvent::Error {
700                stop_reason: sr,
701                error_message: em,
702                usage: u,
703                error_kind: ek,
704            } => {
705                stop_reason = Some(sr);
706                error_message = Some(em);
707                error_kind = ek;
708                if let Some(u) = u {
709                    usage = Some(u);
710                }
711                saw_terminal = true;
712            }
713        }
714    }
715
716    let content = content.ok_or("no Start event found")?;
717    let stop_reason = stop_reason.ok_or("no terminal event (Done or Error) found")?;
718
719    let timestamp = crate::util::now_timestamp();
720
721    Ok(AssistantMessage {
722        content,
723        provider: provider.to_owned(),
724        model_id: model_id.to_owned(),
725        usage: usage.unwrap_or_default(),
726        cost: cost.unwrap_or_default(),
727        stop_reason,
728        error_message,
729        error_kind,
730        timestamp,
731        cache_hint: None,
732    })
733}
734
735// ─── Compile-time Send + Sync assertions ─────────────────────────────────────
736
737const _: () = {
738    const fn assert_send_sync<T: Send + Sync>() {}
739
740    assert_send_sync::<StreamErrorKind>();
741    assert_send_sync::<StreamTransport>();
742    assert_send_sync::<StreamOptions>();
743    assert_send_sync::<AssistantMessageEvent>();
744    assert_send_sync::<AssistantMessageDelta>();
745};
746
#[cfg(test)]
mod tests {
    use super::*;

    // ── Shared fixtures ─────────────────────────────────────────────────────
    //
    // Every scenario in this module operates on a single content block, so the
    // event fixtures pin `content_index` to 0.

    /// Substring expected in the error for a block left open at `Done`.
    const UNTERMINATED: &str = "unterminated content block";

    /// `TextStart` at content index 0.
    fn text_start() -> AssistantMessageEvent {
        AssistantMessageEvent::TextStart { content_index: 0 }
    }

    /// `TextDelta` at content index 0 carrying `delta`.
    fn text_delta(delta: &'static str) -> AssistantMessageEvent {
        AssistantMessageEvent::TextDelta {
            content_index: 0,
            delta: delta.into(),
        }
    }

    /// `TextEnd` at content index 0.
    fn text_end() -> AssistantMessageEvent {
        AssistantMessageEvent::TextEnd { content_index: 0 }
    }

    /// `ToolCallStart` at content index 0 with the given id and tool name.
    fn tool_call_start(id: &'static str, name: &'static str) -> AssistantMessageEvent {
        AssistantMessageEvent::ToolCallStart {
            content_index: 0,
            id: id.into(),
            name: name.into(),
        }
    }

    /// `ToolCallDelta` at content index 0 carrying a JSON fragment.
    fn tool_call_delta(delta: &'static str) -> AssistantMessageEvent {
        AssistantMessageEvent::ToolCallDelta {
            content_index: 0,
            delta: delta.into(),
        }
    }

    /// `ToolCallEnd` at content index 0.
    fn tool_call_end() -> AssistantMessageEvent {
        AssistantMessageEvent::ToolCallEnd { content_index: 0 }
    }

    /// Terminal `Done` event with default usage and cost.
    fn done(stop_reason: StopReason) -> AssistantMessageEvent {
        AssistantMessageEvent::Done {
            stop_reason,
            usage: Usage::default(),
            cost: Cost::default(),
        }
    }

    /// Extracts `error_kind` from an `Error` event; panics on any other variant.
    fn error_kind_of(event: AssistantMessageEvent) -> Option<StreamErrorKind> {
        match event {
            AssistantMessageEvent::Error { error_kind, .. } => error_kind,
            other => panic!("expected Error, got {other:?}"),
        }
    }

    /// Borrows `(arguments, partial_json)` from a `ToolCall` block; panics on
    /// any other variant.
    fn tool_call_fields(block: &ContentBlock) -> (&Value, Option<&str>) {
        match block {
            ContentBlock::ToolCall {
                arguments,
                partial_json,
                ..
            } => (arguments, partial_json.as_deref()),
            other => panic!("expected ToolCall, got {other:?}"),
        }
    }

    /// Assistant message that stopped with `Length`, wrapping `content`.
    fn assistant_with_content(content: Vec<ContentBlock>) -> AssistantMessage {
        AssistantMessage {
            content,
            provider: "test".into(),
            model_id: "test".into(),
            usage: Usage::default(),
            cost: Cost::default(),
            stop_reason: StopReason::Length,
            error_message: None,
            error_kind: None,
            timestamp: 0,
            cache_hint: None,
        }
    }

    // ── Unterminated blocks at the terminal event ───────────────────────────

    /// Regression for #206: a Text block that is opened but never closed
    /// before `Done` must not silently produce a corrupt assistant message.
    #[test]
    fn done_with_unterminated_text_block_is_rejected() {
        let events = vec![
            AssistantMessageEvent::Start,
            text_start(),
            text_delta("hi"),
            done(StopReason::Stop),
        ];
        let err = accumulate_message(events, "test", "test").unwrap_err();
        assert!(err.contains(UNTERMINATED), "got: {err}");
    }

    /// Regression for #206: a missing `ToolCallEnd` must be rejected.
    #[test]
    fn done_with_unterminated_tool_call_block_is_rejected() {
        let events = vec![
            AssistantMessageEvent::Start,
            tool_call_start("t1", "foo"),
            tool_call_delta("{}"),
            done(StopReason::ToolUse),
        ];
        let err = accumulate_message(events, "test", "test").unwrap_err();
        assert!(err.contains(UNTERMINATED), "got: {err}");
    }

    #[test]
    fn done_with_all_blocks_terminated_succeeds() {
        let events = vec![
            AssistantMessageEvent::Start,
            text_start(),
            text_delta("ok"),
            text_end(),
            done(StopReason::Stop),
        ];
        let msg = accumulate_message(events, "test", "test").expect("should succeed");
        assert_eq!(msg.content.len(), 1);
    }

    /// An error may legitimately abort the stream mid-block; the accumulator
    /// must surface it instead of masking it with a validation failure.
    #[test]
    fn error_with_unterminated_block_is_allowed() {
        let events = vec![
            AssistantMessageEvent::Start,
            text_start(),
            AssistantMessageEvent::Error {
                stop_reason: StopReason::Error,
                error_message: "boom".into(),
                usage: None,
                error_kind: None,
            },
        ];
        let msg = accumulate_message(events, "test", "test").expect("error terminal ok");
        assert_eq!(msg.error_message.as_deref(), Some("boom"));
    }

    // ── Error-event constructors ────────────────────────────────────────────

    #[test]
    fn error_constructor_sets_kind_none() {
        let kind = error_kind_of(AssistantMessageEvent::error("boom"));
        assert_eq!(kind, None);
    }

    #[test]
    fn error_throttled_constructor_sets_kind() {
        match AssistantMessageEvent::error_throttled("rate limited") {
            AssistantMessageEvent::Error {
                error_kind,
                error_message,
                ..
            } => {
                assert_eq!(error_kind, Some(StreamErrorKind::Throttled));
                assert_eq!(error_message, "rate limited");
            }
            other => panic!("expected Error, got {other:?}"),
        }
    }

    #[test]
    fn error_context_overflow_constructor_sets_kind() {
        let kind = error_kind_of(AssistantMessageEvent::error_context_overflow("too long"));
        assert_eq!(kind, Some(StreamErrorKind::ContextWindowExceeded));
    }

    #[test]
    fn error_auth_constructor_sets_kind() {
        let kind = error_kind_of(AssistantMessageEvent::error_auth("bad key"));
        assert_eq!(kind, Some(StreamErrorKind::Auth));
    }

    #[test]
    fn error_network_constructor_sets_kind() {
        let kind = error_kind_of(AssistantMessageEvent::error_network("timeout"));
        assert_eq!(kind, Some(StreamErrorKind::Network));
    }

    #[test]
    fn error_content_filtered_constructor_sets_kind() {
        match AssistantMessageEvent::error_content_filtered("blocked by safety filter") {
            AssistantMessageEvent::Error {
                error_kind,
                error_message,
                ..
            } => {
                assert_eq!(error_kind, Some(StreamErrorKind::ContentFiltered));
                assert_eq!(error_message, "blocked by safety filter");
            }
            other => panic!("expected Error, got {other:?}"),
        }
    }

    // ── text_response ───────────────────────────────────────────────────────

    #[test]
    fn text_response_produces_valid_event_sequence() {
        let events = AssistantMessageEvent::text_response("hello world");
        assert_eq!(events.len(), 5);

        // Expected shape: Start, TextStart, TextDelta, TextEnd, Done(Stop).
        assert!(matches!(events[0], AssistantMessageEvent::Start));
        assert!(matches!(
            events[1],
            AssistantMessageEvent::TextStart { content_index: 0 }
        ));
        match &events[2] {
            AssistantMessageEvent::TextDelta {
                content_index,
                delta,
            } => {
                assert_eq!(*content_index, 0);
                assert_eq!(delta, "hello world");
            }
            other => panic!("expected TextDelta, got {other:?}"),
        }
        assert!(matches!(
            events[3],
            AssistantMessageEvent::TextEnd { content_index: 0 }
        ));
        assert!(matches!(
            events[4],
            AssistantMessageEvent::Done {
                stop_reason: StopReason::Stop,
                ..
            }
        ));
    }

    // ── Done(Length) handling ───────────────────────────────────────────────

    /// Regression for #293: `Done(Length)` with an unterminated tool-call block
    /// must NOT be rejected — the block should survive with `partial_json` set
    /// so `recover_incomplete_tool_calls` can convert it to an error result.
    #[test]
    fn done_length_with_unterminated_tool_call_is_tolerated() {
        let events = vec![
            AssistantMessageEvent::Start,
            tool_call_start("tc_1", "read_file"),
            tool_call_delta(r#"{"path": "/tmp"#),
            done(StopReason::Length),
        ];
        let msg = accumulate_message(events, "test", "test")
            .expect("Done(Length) with open tool-call block should succeed");
        assert_eq!(msg.stop_reason, StopReason::Length);

        // The incomplete tool call keeps its raw fragment in `partial_json`.
        let (_, partial_json) = tool_call_fields(&msg.content[0]);
        assert!(
            partial_json.is_some(),
            "partial_json should be Some for incomplete tool call"
        );
    }

    #[test]
    fn done_length_with_unterminated_text_block_is_rejected() {
        let events = vec![
            AssistantMessageEvent::Start,
            text_start(),
            text_delta("partial"),
            done(StopReason::Length),
        ];

        let err = accumulate_message(events, "test", "test").unwrap_err();
        assert!(err.contains(UNTERMINATED), "got: {err}");
    }

    #[test]
    fn done_length_with_unterminated_thinking_block_is_rejected() {
        let events = vec![
            AssistantMessageEvent::Start,
            AssistantMessageEvent::ThinkingStart { content_index: 0 },
            AssistantMessageEvent::ThinkingDelta {
                content_index: 0,
                delta: "partial".into(),
            },
            done(StopReason::Length),
        ];

        let err = accumulate_message(events, "test", "test").unwrap_err();
        assert!(err.contains(UNTERMINATED), "got: {err}");
    }

    #[test]
    fn text_response_accumulates_correctly() {
        let events = AssistantMessageEvent::text_response("accumulated text");
        let msg = accumulate_message(events, "test", "test-model").expect("accumulation failed");

        assert_eq!(msg.content.len(), 1);
        assert_eq!(ContentBlock::extract_text(&msg.content), "accumulated text");
        assert_eq!(msg.stop_reason, StopReason::Stop);
    }

    // ── Events arriving after a block was closed ────────────────────────────

    #[test]
    fn text_delta_after_text_end_is_rejected() {
        let events = vec![
            AssistantMessageEvent::Start,
            text_start(),
            text_delta("hello"),
            text_end(),
            text_delta(" again"),
            done(StopReason::Stop),
        ];

        let err = accumulate_message(events, "test", "test").unwrap_err();
        assert_eq!(err, "TextDelta: block at index 0 is already closed");
    }

    #[test]
    fn duplicate_text_end_is_rejected() {
        let events = vec![
            AssistantMessageEvent::Start,
            text_start(),
            text_delta("hello"),
            text_end(),
            text_end(),
            done(StopReason::Stop),
        ];

        let err = accumulate_message(events, "test", "test").unwrap_err();
        assert_eq!(err, "TextEnd: block at index 0 is already closed");
    }

    #[test]
    fn duplicate_thinking_end_is_rejected() {
        let events = vec![
            AssistantMessageEvent::Start,
            AssistantMessageEvent::ThinkingStart { content_index: 0 },
            AssistantMessageEvent::ThinkingDelta {
                content_index: 0,
                delta: "step 1".into(),
            },
            AssistantMessageEvent::ThinkingEnd {
                content_index: 0,
                signature: Some("sig-1".into()),
            },
            AssistantMessageEvent::ThinkingEnd {
                content_index: 0,
                signature: Some("sig-2".into()),
            },
            done(StopReason::Stop),
        ];

        let err = accumulate_message(events, "test", "test").unwrap_err();
        assert_eq!(err, "ThinkingEnd: block at index 0 is already closed");
    }

    #[test]
    fn tool_call_delta_after_end_is_rejected() {
        let events = vec![
            AssistantMessageEvent::Start,
            tool_call_start("tool-1", "read_file"),
            tool_call_delta(r#"{"path":"/tmp/a"}"#),
            tool_call_end(),
            tool_call_delta(r#","extra":true}"#),
            done(StopReason::ToolUse),
        ];

        let err = accumulate_message(events, "test", "test").unwrap_err();
        assert_eq!(err, "ToolCallDelta: block at index 0 is already closed");
    }

    #[test]
    fn duplicate_tool_call_end_is_rejected() {
        let events = vec![
            AssistantMessageEvent::Start,
            tool_call_start("tool-1", "read_file"),
            tool_call_delta(r#"{"path":"/tmp/a"}"#),
            tool_call_end(),
            tool_call_end(),
            done(StopReason::ToolUse),
        ];

        let err = accumulate_message(events, "test", "test").unwrap_err();
        assert_eq!(err, "ToolCallEnd: block at index 0 is already closed");
    }

    // ── sanitize_incomplete_tool_calls (#619) ──────────────────────────────

    /// Assistant message whose only content is a `read_file` tool call with
    /// the given `arguments` / `partial_json`.
    fn build_assistant_with_tool_call(
        arguments: Value,
        partial_json: Option<String>,
    ) -> AssistantMessage {
        assistant_with_content(vec![ContentBlock::ToolCall {
            id: "tc_1".into(),
            name: "read_file".into(),
            arguments,
            partial_json,
        }])
    }

    #[test]
    fn sanitize_null_arguments_with_partial_json_returns_empty_object() {
        let mut msg =
            build_assistant_with_tool_call(Value::Null, Some(r#"{"path": "/tm"#.into()));
        let fixed = sanitize_incomplete_tool_calls(&mut msg);
        assert_eq!(fixed, 1);

        let (arguments, partial_json) = tool_call_fields(&msg.content[0]);
        assert_eq!(*arguments, Value::Object(serde_json::Map::new()));
        assert!(
            partial_json.is_none(),
            "partial_json must be cleared after scrub"
        );
    }

    #[test]
    fn sanitize_leaves_valid_object_arguments_untouched() {
        let args = serde_json::json!({ "path": "/tmp/a" });
        let mut msg = build_assistant_with_tool_call(args.clone(), None);
        let fixed = sanitize_incomplete_tool_calls(&mut msg);
        assert_eq!(fixed, 0);

        let (arguments, partial_json) = tool_call_fields(&msg.content[0]);
        assert_eq!(*arguments, args);
        assert!(partial_json.is_none());
    }

    /// `Value::String` / arrays / numbers are all not objects — they would
    /// confuse downstream providers even if `partial_json` is absent.
    #[test]
    fn sanitize_coerces_non_object_arguments() {
        let mut msg = build_assistant_with_tool_call(Value::String("truncated".into()), None);
        let fixed = sanitize_incomplete_tool_calls(&mut msg);
        assert_eq!(fixed, 1);

        let (arguments, _) = tool_call_fields(&msg.content[0]);
        assert_eq!(*arguments, Value::Object(serde_json::Map::new()));
    }

    #[test]
    fn sanitize_is_idempotent() {
        let mut msg = build_assistant_with_tool_call(Value::Null, Some(r#"{"path":"#.into()));

        let first_pass = sanitize_incomplete_tool_calls(&mut msg);
        assert_eq!(first_pass, 1);

        // A second pass should be a no-op.
        let second_pass = sanitize_incomplete_tool_calls(&mut msg);
        assert_eq!(second_pass, 0);
    }

    #[test]
    fn sanitize_preserves_non_tool_blocks() {
        let mut msg = assistant_with_content(vec![
            ContentBlock::Text {
                text: "hello".into(),
            },
            ContentBlock::ToolCall {
                id: "tc_1".into(),
                name: "foo".into(),
                arguments: Value::Null,
                partial_json: Some("{".into()),
            },
            ContentBlock::Text {
                text: "world".into(),
            },
        ]);
        let fixed = sanitize_incomplete_tool_calls(&mut msg);
        assert_eq!(fixed, 1);

        // Text blocks preserved in place, on either side of the tool call.
        for (index, expected) in [(0, "hello"), (2, "world")] {
            match &msg.content[index] {
                ContentBlock::Text { text } => assert_eq!(text, expected),
                other => panic!("expected Text, got {other:?}"),
            }
        }
    }

    /// Regression for #619: the canned "Length + partial tool-call" stream from
    /// the issue, when run through `accumulate_message` and then scrubbed,
    /// produces a block suitable for replay to any provider adapter.
    #[test]
    fn accumulate_plus_sanitize_yields_adapter_safe_tool_call() {
        let events = vec![
            AssistantMessageEvent::Start,
            tool_call_start("tc_1", "read_file"),
            tool_call_delta(r#"{"path": "/tm"#),
            done(StopReason::Length),
        ];
        let mut msg = accumulate_message(events, "test", "test")
            .expect("Done(Length) with unterminated tool-call should accumulate");

        // Pre-scrub: partial_json present, arguments null. Scoped so the
        // shared borrow of `msg` ends before the mutable one below.
        {
            let (arguments, partial_json) = tool_call_fields(&msg.content[0]);
            assert!(partial_json.is_some());
            assert!(arguments.is_null());
        }

        sanitize_incomplete_tool_calls(&mut msg);

        // Post-scrub: arguments is an empty object, partial_json cleared.
        let (arguments, partial_json) = tool_call_fields(&msg.content[0]);
        assert!(arguments.is_object());
        assert_eq!(arguments.as_object().unwrap().len(), 0);
        assert!(partial_json.is_none());
    }
}