// stakpak_agent_core/stream.rs

1use serde_json::Value;
2use std::collections::BTreeMap;
3use thiserror::Error;
4
5#[derive(Debug, Clone, PartialEq)]
6pub enum IndexedStreamEvent {
7    TextDelta {
8        content_index: usize,
9        delta: String,
10    },
11    ThinkingDelta {
12        content_index: usize,
13        delta: String,
14    },
15    ToolCallStart {
16        content_index: usize,
17        id: String,
18        name: String,
19    },
20    ToolCallArgumentsDelta {
21        content_index: usize,
22        id: String,
23        delta: String,
24    },
25    ToolCallEnd {
26        content_index: usize,
27        id: String,
28        name: String,
29        arguments: Value,
30        metadata: Option<Value>,
31    },
32}
33
34#[derive(Debug, Clone, PartialEq)]
35pub enum OrderedContentPart {
36    Text(String),
37    Thinking(String),
38    ToolCall {
39        id: String,
40        name: String,
41        arguments: Value,
42        metadata: Option<Value>,
43    },
44}
45
46#[derive(Debug, Error)]
47pub enum StreamAssemblyError {
48    #[error("content index {content_index} changed slot type during streaming")]
49    ContentTypeMismatch { content_index: usize },
50
51    #[error("tool call id mismatch at content index {content_index}")]
52    ToolCallIdMismatch { content_index: usize },
53
54    #[error("invalid tool call arguments for {tool_call_id}: {source}")]
55    InvalidToolCallArguments {
56        tool_call_id: String,
57        #[source]
58        source: serde_json::Error,
59    },
60}
61
62impl PartialEq for StreamAssemblyError {
63    fn eq(&self, other: &Self) -> bool {
64        match (self, other) {
65            (
66                StreamAssemblyError::ContentTypeMismatch {
67                    content_index: left,
68                },
69                StreamAssemblyError::ContentTypeMismatch {
70                    content_index: right,
71                },
72            ) => left == right,
73            (
74                StreamAssemblyError::ToolCallIdMismatch {
75                    content_index: left,
76                },
77                StreamAssemblyError::ToolCallIdMismatch {
78                    content_index: right,
79                },
80            ) => left == right,
81            (
82                StreamAssemblyError::InvalidToolCallArguments {
83                    tool_call_id: left, ..
84                },
85                StreamAssemblyError::InvalidToolCallArguments {
86                    tool_call_id: right,
87                    ..
88                },
89            ) => left == right,
90            _ => false,
91        }
92    }
93}
94
95#[derive(Debug, Clone, PartialEq)]
96enum ContentSlot {
97    Text(String),
98    Thinking(String),
99    ToolCall(ToolCallSlot),
100}
101
102#[derive(Debug, Clone, PartialEq)]
103struct ToolCallSlot {
104    id: String,
105    name: String,
106    arguments_buffer: String,
107    final_arguments: Option<Value>,
108    metadata: Option<Value>,
109}
110
111impl ToolCallSlot {
112    fn new(id: String, name: String) -> Self {
113        Self {
114            id,
115            name,
116            arguments_buffer: String::new(),
117            final_arguments: None,
118            metadata: None,
119        }
120    }
121
122    fn into_part(self) -> Result<OrderedContentPart, StreamAssemblyError> {
123        let arguments = if let Some(arguments) = self.final_arguments {
124            arguments
125        } else if self.arguments_buffer.trim().is_empty() {
126            Value::Object(Default::default())
127        } else {
128            serde_json::from_str(&self.arguments_buffer).map_err(|source| {
129                StreamAssemblyError::InvalidToolCallArguments {
130                    tool_call_id: self.id.clone(),
131                    source,
132                }
133            })?
134        };
135
136        Ok(OrderedContentPart::ToolCall {
137            id: self.id,
138            name: self.name,
139            arguments,
140            metadata: self.metadata,
141        })
142    }
143}
144
145pub fn assemble_ordered_content(
146    events: impl IntoIterator<Item = IndexedStreamEvent>,
147) -> Result<Vec<OrderedContentPart>, StreamAssemblyError> {
148    let mut slots: BTreeMap<usize, ContentSlot> = BTreeMap::new();
149
150    for event in events {
151        match event {
152            IndexedStreamEvent::TextDelta {
153                content_index,
154                delta,
155            } => match slots.get_mut(&content_index) {
156                Some(ContentSlot::Text(text)) => text.push_str(&delta),
157                Some(_) => {
158                    return Err(StreamAssemblyError::ContentTypeMismatch { content_index });
159                }
160                None => {
161                    slots.insert(content_index, ContentSlot::Text(delta));
162                }
163            },
164            IndexedStreamEvent::ThinkingDelta {
165                content_index,
166                delta,
167            } => match slots.get_mut(&content_index) {
168                Some(ContentSlot::Thinking(text)) => text.push_str(&delta),
169                Some(_) => {
170                    return Err(StreamAssemblyError::ContentTypeMismatch { content_index });
171                }
172                None => {
173                    slots.insert(content_index, ContentSlot::Thinking(delta));
174                }
175            },
176            IndexedStreamEvent::ToolCallStart {
177                content_index,
178                id,
179                name,
180            } => match slots.get_mut(&content_index) {
181                Some(ContentSlot::ToolCall(slot)) => {
182                    if slot.id != id {
183                        return Err(StreamAssemblyError::ToolCallIdMismatch { content_index });
184                    }
185                    if slot.name.is_empty() {
186                        slot.name = name;
187                    }
188                }
189                Some(_) => {
190                    return Err(StreamAssemblyError::ContentTypeMismatch { content_index });
191                }
192                None => {
193                    slots.insert(
194                        content_index,
195                        ContentSlot::ToolCall(ToolCallSlot::new(id, name)),
196                    );
197                }
198            },
199            IndexedStreamEvent::ToolCallArgumentsDelta {
200                content_index,
201                id,
202                delta,
203            } => match slots.get_mut(&content_index) {
204                Some(ContentSlot::ToolCall(slot)) => {
205                    if slot.id != id {
206                        return Err(StreamAssemblyError::ToolCallIdMismatch { content_index });
207                    }
208                    slot.arguments_buffer.push_str(&delta);
209                }
210                Some(_) => {
211                    return Err(StreamAssemblyError::ContentTypeMismatch { content_index });
212                }
213                None => {
214                    let mut slot = ToolCallSlot::new(id, String::new());
215                    slot.arguments_buffer.push_str(&delta);
216                    slots.insert(content_index, ContentSlot::ToolCall(slot));
217                }
218            },
219            IndexedStreamEvent::ToolCallEnd {
220                content_index,
221                id,
222                name,
223                arguments,
224                metadata,
225            } => match slots.get_mut(&content_index) {
226                Some(ContentSlot::ToolCall(slot)) => {
227                    if slot.id != id {
228                        return Err(StreamAssemblyError::ToolCallIdMismatch { content_index });
229                    }
230                    if slot.name.is_empty() {
231                        slot.name = name;
232                    }
233                    slot.final_arguments = Some(arguments);
234                    slot.metadata = metadata;
235                }
236                Some(_) => {
237                    return Err(StreamAssemblyError::ContentTypeMismatch { content_index });
238                }
239                None => {
240                    let mut slot = ToolCallSlot::new(id, name);
241                    slot.final_arguments = Some(arguments);
242                    slot.metadata = metadata;
243                    slots.insert(content_index, ContentSlot::ToolCall(slot));
244                }
245            },
246        }
247    }
248
249    slots
250        .into_values()
251        .map(|slot| match slot {
252            ContentSlot::Text(text) => Ok(OrderedContentPart::Text(text)),
253            ContentSlot::Thinking(text) => Ok(OrderedContentPart::Thinking(text)),
254            ContentSlot::ToolCall(slot) => slot.into_part(),
255        })
256        .collect()
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Small constructors so each test reads as a sequence of events.

    fn text(content_index: usize, delta: &str) -> IndexedStreamEvent {
        IndexedStreamEvent::TextDelta {
            content_index,
            delta: delta.to_string(),
        }
    }

    fn thinking(content_index: usize, delta: &str) -> IndexedStreamEvent {
        IndexedStreamEvent::ThinkingDelta {
            content_index,
            delta: delta.to_string(),
        }
    }

    fn tool_start(content_index: usize, id: &str, name: &str) -> IndexedStreamEvent {
        IndexedStreamEvent::ToolCallStart {
            content_index,
            id: id.to_string(),
            name: name.to_string(),
        }
    }

    fn tool_args(content_index: usize, id: &str, delta: &str) -> IndexedStreamEvent {
        IndexedStreamEvent::ToolCallArgumentsDelta {
            content_index,
            id: id.to_string(),
            delta: delta.to_string(),
        }
    }

    /// Events delivered out of index order still come out sorted by index.
    #[test]
    fn preserves_index_order_for_out_of_order_events() {
        let events = vec![text(2, "third"), text(0, "first"), text(1, "second")];

        let expected = vec![
            OrderedContentPart::Text("first".to_string()),
            OrderedContentPart::Text("second".to_string()),
            OrderedContentPart::Text("third".to_string()),
        ];

        assert_eq!(assemble_ordered_content(events), Ok(expected));
    }

    /// Different content kinds keep their relative positions, and buffered
    /// tool arguments are parsed as JSON.
    #[test]
    fn preserves_text_tool_call_thinking_interleaving() {
        let events = vec![
            text(0, "check logs"),
            tool_start(1, "tc_1", "stakpak__run_command"),
            tool_args(1, "tc_1", "{\"cmd\":\"kubectl get pods\"}"),
            thinking(2, "observing cluster state"),
        ];

        let expected = vec![
            OrderedContentPart::Text("check logs".to_string()),
            OrderedContentPart::ToolCall {
                id: "tc_1".to_string(),
                name: "stakpak__run_command".to_string(),
                arguments: json!({"cmd":"kubectl get pods"}),
                metadata: None,
            },
            OrderedContentPart::Thinking("observing cluster state".to_string()),
        ];

        assert_eq!(assemble_ordered_content(events), Ok(expected));
    }

    /// A lone end event is enough to produce a complete tool call.
    #[test]
    fn accepts_tool_call_end_without_start() {
        let events = vec![IndexedStreamEvent::ToolCallEnd {
            content_index: 0,
            id: "tc_1".to_string(),
            name: "stakpak__view".to_string(),
            arguments: json!({"path":"README.md"}),
            metadata: Some(json!({"provider":"gemini"})),
        }];

        let expected = vec![OrderedContentPart::ToolCall {
            id: "tc_1".to_string(),
            name: "stakpak__view".to_string(),
            arguments: json!({"path":"README.md"}),
            metadata: Some(json!({"provider":"gemini"})),
        }];

        assert_eq!(assemble_ordered_content(events), Ok(expected));
    }

    /// A slot cannot switch from text to tool call mid-stream.
    #[test]
    fn errors_on_content_type_mismatch_for_same_index() {
        let events = vec![text(0, "hello"), tool_start(0, "tc_1", "stakpak__view")];

        assert_eq!(
            assemble_ordered_content(events),
            Err(StreamAssemblyError::ContentTypeMismatch { content_index: 0 })
        );
    }

    /// Unparseable buffered arguments surface as an error naming the tool call.
    #[test]
    fn errors_on_invalid_buffered_tool_arguments() {
        let events = vec![
            tool_start(0, "tc_1", "stakpak__view"),
            tool_args(0, "tc_1", "{not json"),
        ];

        match assemble_ordered_content(events) {
            Err(StreamAssemblyError::InvalidToolCallArguments { tool_call_id, .. }) => {
                assert_eq!(tool_call_id, "tc_1");
            }
            other => panic!("expected InvalidToolCallArguments, got {:?}", other),
        }
    }
}