st/proxy/claude/stream.rs
1//! SSE Streaming Parser for the Claude Messages API
2//!
3//! The Claude API sends Server-Sent Events (SSE) when `stream: true`.
4//! Format: `event: <name>\ndata: <json>\n\n`
5//!
6//! This module provides:
7//! - `SseParser`: reads raw bytes from reqwest and yields typed `StreamEvent`s
8//! - `MessageAccumulator`: builds a complete `MessagesResponse` from events
9//!
10//! The tricky part: bytes arrive in arbitrary chunks that may split mid-line.
11//! We buffer until we see `\n\n` (end of SSE event), then parse.
12
13use super::error::ClaudeApiError;
14use super::types::{ContentBlock, MessagesResponse, StopReason, Usage};
15use futures_util::StreamExt;
16use serde::Deserialize;
17use std::pin::Pin;
18
19// ---------------------------------------------------------------------------
20// Stream event types (deserialized from SSE `data:` payloads)
21// ---------------------------------------------------------------------------
22
23/// All possible SSE event types from the Claude streaming API.
24/// Each variant matches an `event: <name>` line in the SSE stream.
25#[derive(Debug, Clone, Deserialize)]
26#[serde(tag = "type", rename_all = "snake_case")]
27pub enum StreamEvent {
28 /// First event - contains the message skeleton (id, model, usage)
29 MessageStart { message: MessagesResponse },
30 /// A new content block is starting at the given index
31 ContentBlockStart {
32 index: usize,
33 content_block: ContentBlock,
34 },
35 /// Incremental content for the block at the given index
36 ContentBlockDelta { index: usize, delta: ContentDelta },
37 /// The block at the given index is complete
38 ContentBlockStop { index: usize },
39 /// Message-level update (stop_reason, final usage)
40 MessageDelta {
41 delta: MessageDeltaPayload,
42 #[serde(skip_serializing_if = "Option::is_none")]
43 usage: Option<Usage>,
44 },
45 /// Message is complete - stream will end after this
46 MessageStop,
47 /// Keepalive ping (ignore)
48 Ping,
49 /// Server-side error delivered via the stream
50 Error { error: super::error::ApiErrorBody },
51}
52
53/// Incremental content within a `content_block_delta` event
54#[derive(Debug, Clone, Deserialize)]
55#[serde(tag = "type", rename_all = "snake_case")]
56pub enum ContentDelta {
57 /// Incremental text (append to current text block)
58 TextDelta { text: String },
59 /// Incremental thinking (append to current thinking block)
60 ThinkingDelta { thinking: String },
61 /// Incremental signature for thinking block verification
62 SignatureDelta { signature: String },
63 /// Incremental JSON for tool input (append and parse when block stops)
64 InputJsonDelta { partial_json: String },
65}
66
67/// Top-level message metadata update
68#[derive(Debug, Clone, Deserialize)]
69pub struct MessageDeltaPayload {
70 pub stop_reason: Option<StopReason>,
71 pub stop_sequence: Option<String>,
72}
73
74// ---------------------------------------------------------------------------
75// SSE Parser - reads bytes, yields StreamEvents
76// ---------------------------------------------------------------------------
77
78/// Parses SSE events from a reqwest byte stream.
79///
80/// # Example
81/// ```rust,no_run
82/// let response = client.post(url).send().await?;
83/// let mut parser = SseParser::new(response);
84/// while let Some(event) = parser.next_event().await? {
85/// match event {
86/// StreamEvent::ContentBlockDelta { delta, .. } => { /* handle */ }
87/// StreamEvent::MessageStop => break,
88/// _ => {}
89/// }
90/// }
91/// ```
92pub struct SseParser {
93 /// The raw byte stream from reqwest
94 stream: Pin<Box<dyn futures_util::Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send>>,
95 /// Accumulated text that hasn't been parsed into an event yet
96 buffer: String,
97 /// Whether the stream has ended
98 done: bool,
99}
100
101impl SseParser {
102 /// Wrap a reqwest response in an SSE parser
103 pub fn new(response: reqwest::Response) -> Self {
104 Self {
105 stream: Box::pin(response.bytes_stream()),
106 buffer: String::new(),
107 done: false,
108 }
109 }
110
111 /// Read the next complete SSE event from the stream.
112 /// Returns `None` when the stream ends.
113 pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ClaudeApiError> {
114 if self.done {
115 return Ok(None);
116 }
117
118 loop {
119 // Check if we already have a complete event in the buffer
120 // SSE events are delimited by blank lines (\n\n)
121 if let Some(event) = self.try_parse_event()? {
122 return Ok(Some(event));
123 }
124
125 // Need more data from the stream
126 match self.stream.next().await {
127 Some(Ok(bytes)) => {
128 // Append raw bytes to our text buffer
129 let text = String::from_utf8_lossy(&bytes);
130 self.buffer.push_str(&text);
131 }
132 Some(Err(e)) => {
133 self.done = true;
134 return Err(ClaudeApiError::Network(e));
135 }
136 None => {
137 // Stream ended - try to parse any remaining data
138 self.done = true;
139 return self.try_parse_event();
140 }
141 }
142 }
143 }
144
145 /// Try to extract one complete SSE event from the buffer.
146 /// An event ends with a blank line (\n\n). Each event has:
147 /// - `event: <name>` line (the event type)
148 /// - `data: <json>` line (the payload, may span multiple lines)
149 fn try_parse_event(&mut self) -> Result<Option<StreamEvent>, ClaudeApiError> {
150 // Find the next complete event (double newline boundary)
151 let boundary = match self.buffer.find("\n\n") {
152 Some(pos) => pos,
153 None => return Ok(None),
154 };
155
156 // Extract the raw event text and remove it from the buffer
157 let raw_event = self.buffer[..boundary].to_string();
158 self.buffer = self.buffer[boundary + 2..].to_string();
159
160 // Parse the SSE fields
161 let mut event_name = String::new();
162 let mut data_lines = Vec::new();
163
164 for line in raw_event.lines() {
165 if let Some(name) = line.strip_prefix("event: ") {
166 event_name = name.trim().to_string();
167 } else if let Some(data) = line.strip_prefix("data: ") {
168 data_lines.push(data);
169 } else if let Some(stripped) = line.strip_prefix("data:") {
170 // `data:` with no space (empty data)
171 data_lines.push(stripped);
172 }
173 // Ignore other lines (comments starting with :, etc.)
174 }
175
176 // Skip events with no data (like keepalive comments)
177 if data_lines.is_empty() {
178 return Ok(None);
179 }
180
181 let data = data_lines.join("\n");
182
183 // Parse the JSON payload based on event type
184 // The Claude API sets the `type` field in the JSON to match the event name,
185 // so we can use serde's tagged enum directly
186 match serde_json::from_str::<StreamEvent>(&data) {
187 Ok(event) => Ok(Some(event)),
188 Err(e) => {
189 // If we can't parse it, log context and return error
190 Err(ClaudeApiError::StreamError {
191 message: format!(
192 "Failed to parse SSE event '{}': {} (data: {})",
193 event_name, e, data
194 ),
195 })
196 }
197 }
198 }
199}
200
201// ---------------------------------------------------------------------------
202// Message Accumulator - builds complete response from stream events
203// ---------------------------------------------------------------------------
204
205/// Accumulates streaming events into a complete `MessagesResponse`.
206///
207/// Useful when you want streaming for timeout protection but need the
208/// final assembled message (like reqwest's `.get_final_message()` pattern).
209#[derive(Default)]
210pub struct MessageAccumulator {
211 /// The response skeleton from message_start
212 response: Option<MessagesResponse>,
213 /// Content blocks being built up from deltas
214 blocks: Vec<BlockBuilder>,
215}
216
/// Internal scratch state for one content block while its deltas arrive.
///
/// A single struct serves all block kinds; the two flags record which kind
/// the block is, and only the fields relevant to that kind end up in the
/// finished `ContentBlock`.
struct BlockBuilder {
    /// Set for thinking blocks.
    is_thinking: bool,
    /// Set for tool-use blocks.
    is_tool_use: bool,
    /// Accumulated text (text blocks).
    text: String,
    /// Accumulated thinking content (thinking blocks).
    thinking: String,
    /// Accumulated signature (thinking blocks).
    signature: String,
    /// Tool call id (tool-use blocks).
    tool_id: String,
    /// Tool name (tool-use blocks).
    tool_name: String,
    /// Concatenated JSON fragments of the tool input (tool-use blocks).
    partial_json: String,
}
228
229impl BlockBuilder {
230 fn new_text() -> Self {
231 Self {
232 text: String::new(),
233 thinking: String::new(),
234 tool_id: String::new(),
235 tool_name: String::new(),
236 partial_json: String::new(),
237 signature: String::new(),
238 is_thinking: false,
239 is_tool_use: false,
240 }
241 }
242
243 fn new_thinking() -> Self {
244 Self {
245 is_thinking: true,
246 ..Self::new_text()
247 }
248 }
249
250 fn new_tool(id: String, name: String) -> Self {
251 Self {
252 tool_id: id,
253 tool_name: name,
254 is_tool_use: true,
255 ..Self::new_text()
256 }
257 }
258
259 /// Convert the accumulated data into a final ContentBlock
260 fn finish(self) -> ContentBlock {
261 if self.is_thinking {
262 ContentBlock::Thinking {
263 thinking: self.thinking,
264 signature: if self.signature.is_empty() {
265 None
266 } else {
267 Some(self.signature)
268 },
269 }
270 } else if self.is_tool_use {
271 let input = serde_json::from_str(&self.partial_json)
272 .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
273 ContentBlock::ToolUse {
274 id: self.tool_id,
275 name: self.tool_name,
276 input,
277 }
278 } else {
279 ContentBlock::Text {
280 text: self.text,
281 cache_control: None,
282 }
283 }
284 }
285}
286
287impl MessageAccumulator {
288 pub fn new() -> Self {
289 Self::default()
290 }
291
292 /// Process a single stream event, updating internal state
293 pub fn process_event(&mut self, event: &StreamEvent) {
294 match event {
295 StreamEvent::MessageStart { message } => {
296 self.response = Some(message.clone());
297 }
298 StreamEvent::ContentBlockStart { content_block, .. } => match content_block {
299 ContentBlock::Thinking { .. } => self.blocks.push(BlockBuilder::new_thinking()),
300 ContentBlock::ToolUse { id, name, .. } => {
301 self.blocks
302 .push(BlockBuilder::new_tool(id.clone(), name.clone()));
303 }
304 _ => self.blocks.push(BlockBuilder::new_text()),
305 },
306 StreamEvent::ContentBlockDelta { index, delta } => {
307 if let Some(block) = self.blocks.get_mut(*index) {
308 match delta {
309 ContentDelta::TextDelta { text } => block.text.push_str(text),
310 ContentDelta::ThinkingDelta { thinking } => {
311 block.thinking.push_str(thinking)
312 }
313 ContentDelta::InputJsonDelta { partial_json } => {
314 block.partial_json.push_str(partial_json);
315 }
316 ContentDelta::SignatureDelta { signature } => {
317 block.signature.push_str(signature);
318 }
319 }
320 }
321 }
322 StreamEvent::MessageDelta { delta, usage } => {
323 if let Some(ref mut resp) = self.response {
324 resp.stop_reason = delta.stop_reason.clone();
325 resp.stop_sequence = delta.stop_sequence.clone();
326 if let Some(u) = usage {
327 resp.usage.output_tokens = u.output_tokens;
328 }
329 }
330 }
331 // ContentBlockStop, MessageStop, Ping - no accumulation needed
332 _ => {}
333 }
334 }
335
336 /// Finalize and return the complete `MessagesResponse`
337 pub fn finish(mut self) -> Result<MessagesResponse, ClaudeApiError> {
338 let mut response = self
339 .response
340 .take()
341 .ok_or_else(|| ClaudeApiError::StreamError {
342 message: "Stream ended without message_start event".to_string(),
343 })?;
344
345 // Replace the skeleton content blocks with our accumulated ones
346 response.content = self.blocks.into_iter().map(|b| b.finish()).collect();
347
348 Ok(response)
349 }
350}