Skip to main content

st/proxy/claude/
stream.rs

//! SSE Streaming Parser for the Claude Messages API
//!
//! The Claude API sends Server-Sent Events (SSE) when `stream: true`.
//! Format: `event: <name>\ndata: <json>\n\n`
//!
//! This module provides:
//! - `SseParser`: reads raw bytes from reqwest and yields typed `StreamEvent`s
//! - `MessageAccumulator`: builds a complete `MessagesResponse` from events
//!
//! The tricky part: bytes arrive in arbitrary chunks that may split mid-line
//! (or even mid-character). We buffer until we see a blank line (`\n\n`, the
//! end of an SSE event), then parse.
12
13use super::error::ClaudeApiError;
14use super::types::{ContentBlock, MessagesResponse, StopReason, Usage};
15use futures_util::StreamExt;
16use serde::Deserialize;
17use std::pin::Pin;
18
// ---------------------------------------------------------------------------
// Stream event types (deserialized from SSE `data:` payloads)
// ---------------------------------------------------------------------------
22
23/// All possible SSE event types from the Claude streaming API.
24/// Each variant matches an `event: <name>` line in the SSE stream.
25#[derive(Debug, Clone, Deserialize)]
26#[serde(tag = "type", rename_all = "snake_case")]
27pub enum StreamEvent {
28    /// First event - contains the message skeleton (id, model, usage)
29    MessageStart { message: MessagesResponse },
30    /// A new content block is starting at the given index
31    ContentBlockStart {
32        index: usize,
33        content_block: ContentBlock,
34    },
35    /// Incremental content for the block at the given index
36    ContentBlockDelta { index: usize, delta: ContentDelta },
37    /// The block at the given index is complete
38    ContentBlockStop { index: usize },
39    /// Message-level update (stop_reason, final usage)
40    MessageDelta {
41        delta: MessageDeltaPayload,
42        #[serde(skip_serializing_if = "Option::is_none")]
43        usage: Option<Usage>,
44    },
45    /// Message is complete - stream will end after this
46    MessageStop,
47    /// Keepalive ping (ignore)
48    Ping,
49    /// Server-side error delivered via the stream
50    Error { error: super::error::ApiErrorBody },
51}
52
53/// Incremental content within a `content_block_delta` event
54#[derive(Debug, Clone, Deserialize)]
55#[serde(tag = "type", rename_all = "snake_case")]
56pub enum ContentDelta {
57    /// Incremental text (append to current text block)
58    TextDelta { text: String },
59    /// Incremental thinking (append to current thinking block)
60    ThinkingDelta { thinking: String },
61    /// Incremental signature for thinking block verification
62    SignatureDelta { signature: String },
63    /// Incremental JSON for tool input (append and parse when block stops)
64    InputJsonDelta { partial_json: String },
65}
66
67/// Top-level message metadata update
68#[derive(Debug, Clone, Deserialize)]
69pub struct MessageDeltaPayload {
70    pub stop_reason: Option<StopReason>,
71    pub stop_sequence: Option<String>,
72}
73
// ---------------------------------------------------------------------------
// SSE Parser - reads bytes, yields StreamEvents
// ---------------------------------------------------------------------------
77
78/// Parses SSE events from a reqwest byte stream.
79///
80/// # Example
81/// ```rust,no_run
82/// let response = client.post(url).send().await?;
83/// let mut parser = SseParser::new(response);
84/// while let Some(event) = parser.next_event().await? {
85///     match event {
86///         StreamEvent::ContentBlockDelta { delta, .. } => { /* handle */ }
87///         StreamEvent::MessageStop => break,
88///         _ => {}
89///     }
90/// }
91/// ```
92pub struct SseParser {
93    /// The raw byte stream from reqwest
94    stream: Pin<Box<dyn futures_util::Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send>>,
95    /// Accumulated text that hasn't been parsed into an event yet
96    buffer: String,
97    /// Whether the stream has ended
98    done: bool,
99}
100
101impl SseParser {
102    /// Wrap a reqwest response in an SSE parser
103    pub fn new(response: reqwest::Response) -> Self {
104        Self {
105            stream: Box::pin(response.bytes_stream()),
106            buffer: String::new(),
107            done: false,
108        }
109    }
110
111    /// Read the next complete SSE event from the stream.
112    /// Returns `None` when the stream ends.
113    pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ClaudeApiError> {
114        if self.done {
115            return Ok(None);
116        }
117
118        loop {
119            // Check if we already have a complete event in the buffer
120            // SSE events are delimited by blank lines (\n\n)
121            if let Some(event) = self.try_parse_event()? {
122                return Ok(Some(event));
123            }
124
125            // Need more data from the stream
126            match self.stream.next().await {
127                Some(Ok(bytes)) => {
128                    // Append raw bytes to our text buffer
129                    let text = String::from_utf8_lossy(&bytes);
130                    self.buffer.push_str(&text);
131                }
132                Some(Err(e)) => {
133                    self.done = true;
134                    return Err(ClaudeApiError::Network(e));
135                }
136                None => {
137                    // Stream ended - try to parse any remaining data
138                    self.done = true;
139                    return self.try_parse_event();
140                }
141            }
142        }
143    }
144
145    /// Try to extract one complete SSE event from the buffer.
146    /// An event ends with a blank line (\n\n). Each event has:
147    /// - `event: <name>` line (the event type)
148    /// - `data: <json>` line (the payload, may span multiple lines)
149    fn try_parse_event(&mut self) -> Result<Option<StreamEvent>, ClaudeApiError> {
150        // Find the next complete event (double newline boundary)
151        let boundary = match self.buffer.find("\n\n") {
152            Some(pos) => pos,
153            None => return Ok(None),
154        };
155
156        // Extract the raw event text and remove it from the buffer
157        let raw_event = self.buffer[..boundary].to_string();
158        self.buffer = self.buffer[boundary + 2..].to_string();
159
160        // Parse the SSE fields
161        let mut event_name = String::new();
162        let mut data_lines = Vec::new();
163
164        for line in raw_event.lines() {
165            if let Some(name) = line.strip_prefix("event: ") {
166                event_name = name.trim().to_string();
167            } else if let Some(data) = line.strip_prefix("data: ") {
168                data_lines.push(data);
169            } else if let Some(stripped) = line.strip_prefix("data:") {
170                // `data:` with no space (empty data)
171                data_lines.push(stripped);
172            }
173            // Ignore other lines (comments starting with :, etc.)
174        }
175
176        // Skip events with no data (like keepalive comments)
177        if data_lines.is_empty() {
178            return Ok(None);
179        }
180
181        let data = data_lines.join("\n");
182
183        // Parse the JSON payload based on event type
184        // The Claude API sets the `type` field in the JSON to match the event name,
185        // so we can use serde's tagged enum directly
186        match serde_json::from_str::<StreamEvent>(&data) {
187            Ok(event) => Ok(Some(event)),
188            Err(e) => {
189                // If we can't parse it, log context and return error
190                Err(ClaudeApiError::StreamError {
191                    message: format!(
192                        "Failed to parse SSE event '{}': {} (data: {})",
193                        event_name, e, data
194                    ),
195                })
196            }
197        }
198    }
199}
200
// ---------------------------------------------------------------------------
// Message Accumulator - builds complete response from stream events
// ---------------------------------------------------------------------------
204
205/// Accumulates streaming events into a complete `MessagesResponse`.
206///
207/// Useful when you want streaming for timeout protection but need the
208/// final assembled message (like reqwest's `.get_final_message()` pattern).
209#[derive(Default)]
210pub struct MessageAccumulator {
211    /// The response skeleton from message_start
212    response: Option<MessagesResponse>,
213    /// Content blocks being built up from deltas
214    blocks: Vec<BlockBuilder>,
215}
216
217/// Internal: tracks a content block being assembled from deltas
218struct BlockBuilder {
219    text: String,
220    thinking: String,
221    tool_id: String,
222    tool_name: String,
223    partial_json: String,
224    signature: String,
225    is_thinking: bool,
226    is_tool_use: bool,
227}
228
229impl BlockBuilder {
230    fn new_text() -> Self {
231        Self {
232            text: String::new(),
233            thinking: String::new(),
234            tool_id: String::new(),
235            tool_name: String::new(),
236            partial_json: String::new(),
237            signature: String::new(),
238            is_thinking: false,
239            is_tool_use: false,
240        }
241    }
242
243    fn new_thinking() -> Self {
244        Self {
245            is_thinking: true,
246            ..Self::new_text()
247        }
248    }
249
250    fn new_tool(id: String, name: String) -> Self {
251        Self {
252            tool_id: id,
253            tool_name: name,
254            is_tool_use: true,
255            ..Self::new_text()
256        }
257    }
258
259    /// Convert the accumulated data into a final ContentBlock
260    fn finish(self) -> ContentBlock {
261        if self.is_thinking {
262            ContentBlock::Thinking {
263                thinking: self.thinking,
264                signature: if self.signature.is_empty() {
265                    None
266                } else {
267                    Some(self.signature)
268                },
269            }
270        } else if self.is_tool_use {
271            let input = serde_json::from_str(&self.partial_json)
272                .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
273            ContentBlock::ToolUse {
274                id: self.tool_id,
275                name: self.tool_name,
276                input,
277            }
278        } else {
279            ContentBlock::Text {
280                text: self.text,
281                cache_control: None,
282            }
283        }
284    }
285}
286
287impl MessageAccumulator {
288    pub fn new() -> Self {
289        Self::default()
290    }
291
292    /// Process a single stream event, updating internal state
293    pub fn process_event(&mut self, event: &StreamEvent) {
294        match event {
295            StreamEvent::MessageStart { message } => {
296                self.response = Some(message.clone());
297            }
298            StreamEvent::ContentBlockStart { content_block, .. } => match content_block {
299                ContentBlock::Thinking { .. } => self.blocks.push(BlockBuilder::new_thinking()),
300                ContentBlock::ToolUse { id, name, .. } => {
301                    self.blocks
302                        .push(BlockBuilder::new_tool(id.clone(), name.clone()));
303                }
304                _ => self.blocks.push(BlockBuilder::new_text()),
305            },
306            StreamEvent::ContentBlockDelta { index, delta } => {
307                if let Some(block) = self.blocks.get_mut(*index) {
308                    match delta {
309                        ContentDelta::TextDelta { text } => block.text.push_str(text),
310                        ContentDelta::ThinkingDelta { thinking } => {
311                            block.thinking.push_str(thinking)
312                        }
313                        ContentDelta::InputJsonDelta { partial_json } => {
314                            block.partial_json.push_str(partial_json);
315                        }
316                        ContentDelta::SignatureDelta { signature } => {
317                            block.signature.push_str(signature);
318                        }
319                    }
320                }
321            }
322            StreamEvent::MessageDelta { delta, usage } => {
323                if let Some(ref mut resp) = self.response {
324                    resp.stop_reason = delta.stop_reason.clone();
325                    resp.stop_sequence = delta.stop_sequence.clone();
326                    if let Some(u) = usage {
327                        resp.usage.output_tokens = u.output_tokens;
328                    }
329                }
330            }
331            // ContentBlockStop, MessageStop, Ping - no accumulation needed
332            _ => {}
333        }
334    }
335
336    /// Finalize and return the complete `MessagesResponse`
337    pub fn finish(mut self) -> Result<MessagesResponse, ClaudeApiError> {
338        let mut response = self
339            .response
340            .take()
341            .ok_or_else(|| ClaudeApiError::StreamError {
342                message: "Stream ended without message_start event".to_string(),
343            })?;
344
345        // Replace the skeleton content blocks with our accumulated ones
346        response.content = self.blocks.into_iter().map(|b| b.finish()).collect();
347
348        Ok(response)
349    }
350}