1use serde_json::Value;
2use std::collections::BTreeMap;
3use thiserror::Error;
4
/// A single streaming event tagged with the index of the content slot it belongs to.
///
/// Events are consumed by [`assemble_ordered_content`], which groups them by
/// `content_index` and emits one assembled part per index.
#[derive(Debug, Clone, PartialEq)]
pub enum IndexedStreamEvent {
    /// A fragment of plain text to append to the text slot at `content_index`.
    TextDelta {
        content_index: usize,
        delta: String,
    },
    /// A fragment of thinking text to append to the thinking slot at `content_index`.
    ThinkingDelta {
        content_index: usize,
        delta: String,
    },
    /// Announces a tool call, supplying its `id` and `name`.
    ToolCallStart {
        content_index: usize,
        id: String,
        name: String,
    },
    /// A fragment of the tool call's raw arguments text; fragments are concatenated
    /// and parsed as JSON only when no `ToolCallEnd` supplied final arguments.
    ToolCallArgumentsDelta {
        content_index: usize,
        id: String,
        delta: String,
    },
    /// Completes a tool call with already-parsed `arguments` and optional `metadata`.
    /// These arguments take precedence over any buffered argument fragments.
    ToolCallEnd {
        content_index: usize,
        id: String,
        name: String,
        arguments: Value,
        metadata: Option<Value>,
    },
}
33
/// One fully assembled content part, as returned by [`assemble_ordered_content`].
#[derive(Debug, Clone, PartialEq)]
pub enum OrderedContentPart {
    /// Concatenation of every text delta received for the slot.
    Text(String),
    /// Concatenation of every thinking delta received for the slot.
    Thinking(String),
    /// A tool call with its resolved JSON arguments.
    ToolCall {
        id: String,
        name: String,
        /// Either the arguments supplied by the end event, the parsed argument
        /// buffer, or an empty JSON object when the buffer was blank.
        arguments: Value,
        /// Metadata supplied by the end event, if any.
        metadata: Option<Value>,
    },
}
45
/// Errors reported while assembling indexed stream events into ordered parts.
#[derive(Debug, Error)]
pub enum StreamAssemblyError {
    /// An event targeted a content index already occupied by a slot of another kind
    /// (for example a tool-call event arriving at an index that holds text).
    #[error("content index {content_index} changed slot type during streaming")]
    ContentTypeMismatch { content_index: usize },

    /// A tool-call event carried an id different from the one already recorded
    /// for the tool call at that content index.
    #[error("tool call id mismatch at content index {content_index}")]
    ToolCallIdMismatch { content_index: usize },

    /// The buffered argument text of a tool call could not be parsed as JSON.
    #[error("invalid tool call arguments for {tool_call_id}: {source}")]
    InvalidToolCallArguments {
        tool_call_id: String,
        #[source]
        source: serde_json::Error,
    },
}
61
62impl PartialEq for StreamAssemblyError {
63 fn eq(&self, other: &Self) -> bool {
64 match (self, other) {
65 (
66 StreamAssemblyError::ContentTypeMismatch {
67 content_index: left,
68 },
69 StreamAssemblyError::ContentTypeMismatch {
70 content_index: right,
71 },
72 ) => left == right,
73 (
74 StreamAssemblyError::ToolCallIdMismatch {
75 content_index: left,
76 },
77 StreamAssemblyError::ToolCallIdMismatch {
78 content_index: right,
79 },
80 ) => left == right,
81 (
82 StreamAssemblyError::InvalidToolCallArguments {
83 tool_call_id: left, ..
84 },
85 StreamAssemblyError::InvalidToolCallArguments {
86 tool_call_id: right,
87 ..
88 },
89 ) => left == right,
90 _ => false,
91 }
92 }
93}
94
/// In-progress accumulator for the content at one index while events are consumed.
#[derive(Debug, Clone, PartialEq)]
enum ContentSlot {
    /// Text accumulated so far.
    Text(String),
    /// Thinking text accumulated so far.
    Thinking(String),
    /// Tool call state accumulated so far.
    ToolCall(ToolCallSlot),
}
101
/// Accumulated state of a single tool call while its events are being consumed.
#[derive(Debug, Clone, PartialEq)]
struct ToolCallSlot {
    id: String,
    // May be empty when the slot was created from an arguments delta.
    name: String,
    // Raw argument text concatenated from argument deltas.
    arguments_buffer: String,
    // Parsed arguments from an end event; when set, the buffer is not parsed.
    final_arguments: Option<Value>,
    // Metadata from an end event, if any.
    metadata: Option<Value>,
}
110
111impl ToolCallSlot {
112 fn new(id: String, name: String) -> Self {
113 Self {
114 id,
115 name,
116 arguments_buffer: String::new(),
117 final_arguments: None,
118 metadata: None,
119 }
120 }
121
122 fn into_part(self) -> Result<OrderedContentPart, StreamAssemblyError> {
123 let arguments = if let Some(arguments) = self.final_arguments {
124 arguments
125 } else if self.arguments_buffer.trim().is_empty() {
126 Value::Object(Default::default())
127 } else {
128 serde_json::from_str(&self.arguments_buffer).map_err(|source| {
129 StreamAssemblyError::InvalidToolCallArguments {
130 tool_call_id: self.id.clone(),
131 source,
132 }
133 })?
134 };
135
136 Ok(OrderedContentPart::ToolCall {
137 id: self.id,
138 name: self.name,
139 arguments,
140 metadata: self.metadata,
141 })
142 }
143}
144
145pub fn assemble_ordered_content(
146 events: impl IntoIterator<Item = IndexedStreamEvent>,
147) -> Result<Vec<OrderedContentPart>, StreamAssemblyError> {
148 let mut slots: BTreeMap<usize, ContentSlot> = BTreeMap::new();
149
150 for event in events {
151 match event {
152 IndexedStreamEvent::TextDelta {
153 content_index,
154 delta,
155 } => match slots.get_mut(&content_index) {
156 Some(ContentSlot::Text(text)) => text.push_str(&delta),
157 Some(_) => {
158 return Err(StreamAssemblyError::ContentTypeMismatch { content_index });
159 }
160 None => {
161 slots.insert(content_index, ContentSlot::Text(delta));
162 }
163 },
164 IndexedStreamEvent::ThinkingDelta {
165 content_index,
166 delta,
167 } => match slots.get_mut(&content_index) {
168 Some(ContentSlot::Thinking(text)) => text.push_str(&delta),
169 Some(_) => {
170 return Err(StreamAssemblyError::ContentTypeMismatch { content_index });
171 }
172 None => {
173 slots.insert(content_index, ContentSlot::Thinking(delta));
174 }
175 },
176 IndexedStreamEvent::ToolCallStart {
177 content_index,
178 id,
179 name,
180 } => match slots.get_mut(&content_index) {
181 Some(ContentSlot::ToolCall(slot)) => {
182 if slot.id != id {
183 return Err(StreamAssemblyError::ToolCallIdMismatch { content_index });
184 }
185 if slot.name.is_empty() {
186 slot.name = name;
187 }
188 }
189 Some(_) => {
190 return Err(StreamAssemblyError::ContentTypeMismatch { content_index });
191 }
192 None => {
193 slots.insert(
194 content_index,
195 ContentSlot::ToolCall(ToolCallSlot::new(id, name)),
196 );
197 }
198 },
199 IndexedStreamEvent::ToolCallArgumentsDelta {
200 content_index,
201 id,
202 delta,
203 } => match slots.get_mut(&content_index) {
204 Some(ContentSlot::ToolCall(slot)) => {
205 if slot.id != id {
206 return Err(StreamAssemblyError::ToolCallIdMismatch { content_index });
207 }
208 slot.arguments_buffer.push_str(&delta);
209 }
210 Some(_) => {
211 return Err(StreamAssemblyError::ContentTypeMismatch { content_index });
212 }
213 None => {
214 let mut slot = ToolCallSlot::new(id, String::new());
215 slot.arguments_buffer.push_str(&delta);
216 slots.insert(content_index, ContentSlot::ToolCall(slot));
217 }
218 },
219 IndexedStreamEvent::ToolCallEnd {
220 content_index,
221 id,
222 name,
223 arguments,
224 metadata,
225 } => match slots.get_mut(&content_index) {
226 Some(ContentSlot::ToolCall(slot)) => {
227 if slot.id != id {
228 return Err(StreamAssemblyError::ToolCallIdMismatch { content_index });
229 }
230 if slot.name.is_empty() {
231 slot.name = name;
232 }
233 slot.final_arguments = Some(arguments);
234 slot.metadata = metadata;
235 }
236 Some(_) => {
237 return Err(StreamAssemblyError::ContentTypeMismatch { content_index });
238 }
239 None => {
240 let mut slot = ToolCallSlot::new(id, name);
241 slot.final_arguments = Some(arguments);
242 slot.metadata = metadata;
243 slots.insert(content_index, ContentSlot::ToolCall(slot));
244 }
245 },
246 }
247 }
248
249 slots
250 .into_values()
251 .map(|slot| match slot {
252 ContentSlot::Text(text) => Ok(OrderedContentPart::Text(text)),
253 ContentSlot::Thinking(text) => Ok(OrderedContentPart::Thinking(text)),
254 ContentSlot::ToolCall(slot) => slot.into_part(),
255 })
256 .collect()
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a `TextDelta` event for the given slot.
    fn text_delta(content_index: usize, delta: &str) -> IndexedStreamEvent {
        IndexedStreamEvent::TextDelta {
            content_index,
            delta: String::from(delta),
        }
    }

    /// Builds a `ToolCallStart` event for the given slot.
    fn tool_call_start(content_index: usize, id: &str, name: &str) -> IndexedStreamEvent {
        IndexedStreamEvent::ToolCallStart {
            content_index,
            id: String::from(id),
            name: String::from(name),
        }
    }

    /// Builds a `ToolCallArgumentsDelta` event for the given slot.
    fn arguments_delta(content_index: usize, id: &str, delta: &str) -> IndexedStreamEvent {
        IndexedStreamEvent::ToolCallArgumentsDelta {
            content_index,
            id: String::from(id),
            delta: String::from(delta),
        }
    }

    #[test]
    fn preserves_index_order_for_out_of_order_events() {
        let events = vec![
            text_delta(2, "third"),
            text_delta(0, "first"),
            text_delta(1, "second"),
        ];

        let expected = vec![
            OrderedContentPart::Text(String::from("first")),
            OrderedContentPart::Text(String::from("second")),
            OrderedContentPart::Text(String::from("third")),
        ];

        assert_eq!(assemble_ordered_content(events), Ok(expected));
    }

    #[test]
    fn preserves_text_tool_call_thinking_interleaving() {
        let events = vec![
            text_delta(0, "check logs"),
            tool_call_start(1, "tc_1", "stakpak__run_command"),
            arguments_delta(1, "tc_1", "{\"cmd\":\"kubectl get pods\"}"),
            IndexedStreamEvent::ThinkingDelta {
                content_index: 2,
                delta: String::from("observing cluster state"),
            },
        ];

        let expected = vec![
            OrderedContentPart::Text(String::from("check logs")),
            OrderedContentPart::ToolCall {
                id: String::from("tc_1"),
                name: String::from("stakpak__run_command"),
                arguments: json!({"cmd":"kubectl get pods"}),
                metadata: None,
            },
            OrderedContentPart::Thinking(String::from("observing cluster state")),
        ];

        assert_eq!(assemble_ordered_content(events), Ok(expected));
    }

    #[test]
    fn accepts_tool_call_end_without_start() {
        let events = vec![IndexedStreamEvent::ToolCallEnd {
            content_index: 0,
            id: String::from("tc_1"),
            name: String::from("stakpak__view"),
            arguments: json!({"path":"README.md"}),
            metadata: Some(json!({"provider":"gemini"})),
        }];

        let expected = vec![OrderedContentPart::ToolCall {
            id: String::from("tc_1"),
            name: String::from("stakpak__view"),
            arguments: json!({"path":"README.md"}),
            metadata: Some(json!({"provider":"gemini"})),
        }];

        assert_eq!(assemble_ordered_content(events), Ok(expected));
    }

    #[test]
    fn errors_on_content_type_mismatch_for_same_index() {
        let events = vec![
            text_delta(0, "hello"),
            tool_call_start(0, "tc_1", "stakpak__view"),
        ];

        assert_eq!(
            assemble_ordered_content(events),
            Err(StreamAssemblyError::ContentTypeMismatch { content_index: 0 })
        );
    }

    #[test]
    fn errors_on_invalid_buffered_tool_arguments() {
        let events = vec![
            tool_call_start(0, "tc_1", "stakpak__view"),
            arguments_delta(0, "tc_1", "{not json"),
        ];

        let outcome = assemble_ordered_content(events);

        assert!(matches!(
            outcome,
            Err(StreamAssemblyError::InvalidToolCallArguments { tool_call_id, .. })
                if tool_call_id == "tc_1"
        ));
    }
}