Skip to main content

vtcode_core/llm/providers/shared/
responses_stream.rs

1use crate::llm::error_display;
2use crate::llm::provider::{LLMError, LLMNormalizedStream, LLMResponse, NormalizedStreamEvent};
3use crate::llm::providers::shared::{extract_data_payload, find_sse_boundary};
4use async_stream::try_stream;
5use futures::StreamExt;
6use hashbrown::{HashMap, HashSet};
7use serde_json::{Value, json};
8
9use super::StreamAggregator;
10
/// Configuration for a normalized Responses-API event stream.
#[derive(Debug, Clone)]
pub struct ResponsesNormalizedStreamOptions {
    /// Provider label used when formatting stream errors.
    pub provider_name: &'static str,
    /// Model identifier handed to the stream aggregator.
    pub model: String,
    /// When `false`, reasoning deltas are not emitted as stream events.
    pub emit_reasoning: bool,
}
16
17struct ResponsesNormalizedStreamProcessor<P> {
18    options: ResponsesNormalizedStreamOptions,
19    parse_final_response: P,
20    aggregator: StreamAggregator,
21    seen_tool_calls: HashSet<String>,
22    tool_call_indexes: HashMap<String, usize>,
23    tool_call_names: HashMap<String, String>,
24    next_tool_call_index: usize,
25    final_response: Option<Value>,
26    done: bool,
27}
28
29impl<P> ResponsesNormalizedStreamProcessor<P>
30where
31    P: Fn(Value) -> Result<LLMResponse, LLMError>,
32{
33    fn new(options: ResponsesNormalizedStreamOptions, parse_final_response: P) -> Self {
34        Self {
35            aggregator: StreamAggregator::new(options.model.clone()),
36            options,
37            parse_final_response,
38            seen_tool_calls: HashSet::new(),
39            tool_call_indexes: HashMap::new(),
40            tool_call_names: HashMap::new(),
41            next_tool_call_index: 0,
42            final_response: None,
43            done: false,
44        }
45    }
46
47    fn is_done(&self) -> bool {
48        self.done
49    }
50
51    fn handle_payload(&mut self, payload: Value) -> Result<Vec<NormalizedStreamEvent>, LLMError> {
52        let mut events = Vec::new();
53
54        if let Some(usage) = payload.get("usage").cloned()
55            && let Ok(usage) = serde_json::from_value(usage)
56        {
57            self.aggregator.set_usage(usage);
58        }
59
60        let event_type = payload.get("type").and_then(Value::as_str).unwrap_or("");
61        match event_type {
62            "response.output_text.delta" => {
63                let delta = payload
64                    .get("delta")
65                    .and_then(Value::as_str)
66                    .ok_or_else(|| provider_error(self.options.provider_name, "missing delta"))?;
67                for event in self.aggregator.handle_content(delta) {
68                    match event {
69                        crate::llm::provider::LLMStreamEvent::Token { delta } => {
70                            events.push(NormalizedStreamEvent::TextDelta { delta });
71                        }
72                        crate::llm::provider::LLMStreamEvent::Reasoning { delta }
73                            if self.options.emit_reasoning =>
74                        {
75                            events.push(NormalizedStreamEvent::ReasoningDelta { delta });
76                        }
77                        _ => {}
78                    }
79                }
80            }
81            "response.refusal.delta" => {
82                let delta = payload
83                    .get("delta")
84                    .and_then(Value::as_str)
85                    .ok_or_else(|| provider_error(self.options.provider_name, "missing delta"))?;
86                if !delta.is_empty() {
87                    self.aggregator.content.push_str(delta);
88                    events.push(NormalizedStreamEvent::TextDelta {
89                        delta: delta.to_string(),
90                    });
91                }
92            }
93            "response.reasoning_text.delta"
94            | "response.reasoning_summary_text.delta"
95            | "response.reasoning_content.delta" => {
96                if self.options.emit_reasoning
97                    && let Some(delta) = payload.get("delta").and_then(Value::as_str)
98                    && let Some(delta) = self.aggregator.handle_reasoning(delta)
99                {
100                    events.push(NormalizedStreamEvent::ReasoningDelta { delta });
101                }
102            }
103            "response.output_item.added" | "response.output_item.done" => {
104                if let Some(item) = payload.get("item") {
105                    let tool_call = self.capture_tool_call_metadata(
106                        item,
107                        payload
108                            .get("output_index")
109                            .and_then(Value::as_u64)
110                            .map(|value| value as usize),
111                    );
112                    if let Some((call_id, name)) = tool_call {
113                        self.push_tool_call_start(&mut events, call_id, Some(name));
114                    }
115                }
116            }
117            "response.function_call_arguments.delta" => {
118                let delta = payload
119                    .get("delta")
120                    .and_then(Value::as_str)
121                    .ok_or_else(|| provider_error(self.options.provider_name, "missing delta"))?;
122                let call_id = payload
123                    .get("item_id")
124                    .and_then(Value::as_str)
125                    .or_else(|| payload.get("call_id").and_then(Value::as_str))
126                    .filter(|value| !value.is_empty())
127                    .map(ToOwned::to_owned)
128                    .unwrap_or_else(|| format!("tool_call_{}", self.next_tool_call_index));
129                let index = self.resolve_tool_call_index(
130                    &call_id,
131                    payload
132                        .get("output_index")
133                        .and_then(Value::as_u64)
134                        .map(|value| value as usize),
135                );
136
137                let name = self.tool_call_names.get(&call_id).cloned();
138                self.push_tool_call_start(&mut events, call_id.clone(), name);
139
140                if !delta.is_empty() {
141                    self.aggregator.handle_tool_calls(&[json!({
142                        "index": index,
143                        "id": call_id,
144                        "function": {
145                            "arguments": delta,
146                        }
147                    })]);
148                    events.push(NormalizedStreamEvent::ToolCallDelta {
149                        call_id,
150                        delta: delta.to_string(),
151                    });
152                }
153            }
154            "response.completed" => {
155                if let Some(response) = payload.get("response") {
156                    self.final_response = Some(response.clone());
157                }
158                self.done = true;
159            }
160            "response.failed" | "response.incomplete" | "error" => {
161                let message = extract_error_message(&payload)
162                    .unwrap_or_else(|| "unknown error from responses stream".to_string());
163                return Err(provider_error(self.options.provider_name, message));
164            }
165            _ => {}
166        }
167
168        Ok(events)
169    }
170
171    fn finish(self) -> Result<Vec<NormalizedStreamEvent>, LLMError> {
172        let streamed = self.aggregator.finalize();
173        let mut response = if let Some(final_response) = self.final_response {
174            (self.parse_final_response)(final_response)?
175        } else {
176            streamed.clone()
177        };
178
179        merge_streamed_response(&mut response, streamed);
180
181        let mut events = Vec::new();
182        if let Some(usage) = response.usage.clone() {
183            events.push(NormalizedStreamEvent::Usage { usage });
184        }
185        events.push(NormalizedStreamEvent::Done {
186            response: Box::new(response),
187        });
188        Ok(events)
189    }
190
191    fn capture_tool_call_metadata(
192        &mut self,
193        item: &Value,
194        output_index: Option<usize>,
195    ) -> Option<(String, String)> {
196        let item_type = item.get("type").and_then(Value::as_str).unwrap_or("");
197        if item_type != "function_call" {
198            return None;
199        }
200
201        let call_id = item
202            .get("id")
203            .and_then(Value::as_str)
204            .or_else(|| item.get("call_id").and_then(Value::as_str))
205            .filter(|value| !value.is_empty());
206        let name = item.get("name").and_then(Value::as_str).or_else(|| {
207            item.get("function")
208                .and_then(|function| function.get("name"))
209                .and_then(Value::as_str)
210        });
211        if let (Some(call_id), Some(name)) = (call_id, name) {
212            self.tool_call_names
213                .entry(call_id.to_string())
214                .or_insert_with(|| name.to_string());
215            let index = self.resolve_tool_call_index(call_id, output_index);
216            self.aggregator.handle_tool_calls(&[json!({
217                "index": index,
218                "id": call_id,
219                "function": {
220                    "name": name,
221                }
222            })]);
223            return Some((call_id.to_string(), name.to_string()));
224        }
225
226        None
227    }
228
229    fn push_tool_call_start(
230        &mut self,
231        events: &mut Vec<NormalizedStreamEvent>,
232        call_id: String,
233        name: Option<String>,
234    ) {
235        if self.seen_tool_calls.insert(call_id.clone()) {
236            events.push(NormalizedStreamEvent::ToolCallStart { call_id, name });
237        }
238    }
239
240    fn resolve_tool_call_index(&mut self, call_id: &str, output_index: Option<usize>) -> usize {
241        if let Some(index) = output_index {
242            self.tool_call_indexes.insert(call_id.to_string(), index);
243            self.next_tool_call_index = self.next_tool_call_index.max(index + 1);
244            return index;
245        }
246
247        if let Some(index) = self.tool_call_indexes.get(call_id).copied() {
248            return index;
249        }
250
251        let index = self.next_tool_call_index;
252        self.tool_call_indexes.insert(call_id.to_string(), index);
253        self.next_tool_call_index += 1;
254        index
255    }
256}
257
258pub fn create_responses_normalized_stream<P>(
259    response: reqwest::Response,
260    options: ResponsesNormalizedStreamOptions,
261    parse_final_response: P,
262) -> LLMNormalizedStream
263where
264    P: Fn(Value) -> Result<LLMResponse, LLMError> + Send + 'static,
265{
266    let stream = try_stream! {
267        let provider_name = options.provider_name;
268        let mut processor = ResponsesNormalizedStreamProcessor::new(options, parse_final_response);
269        let mut body_stream = response.bytes_stream();
270        let mut buffer = String::new();
271
272        while let Some(chunk_result) = body_stream.next().await {
273            let chunk = chunk_result.map_err(|err| provider_error(
274                provider_name,
275                format!("streaming error: {err}"),
276            ))?;
277            buffer.push_str(&String::from_utf8_lossy(&chunk));
278
279            while let Some((split_idx, delimiter_len)) = find_sse_boundary(&buffer) {
280                let event = buffer[..split_idx].to_string();
281                buffer.drain(..split_idx + delimiter_len);
282
283                if let Some(data_payload) = extract_data_payload(&event) {
284                    let trimmed_payload = data_payload.trim();
285                    if trimmed_payload.is_empty() || trimmed_payload == "[DONE]" {
286                        continue;
287                    }
288
289                    let payload: Value = serde_json::from_str(trimmed_payload).map_err(|err| {
290                        provider_error(provider_name, format!("invalid stream payload: {err}"))
291                    })?;
292
293                    for event in processor.handle_payload(payload)? {
294                        yield event;
295                    }
296
297                    if processor.is_done() {
298                        break;
299                    }
300                }
301            }
302
303            if processor.is_done() {
304                break;
305            }
306        }
307
308        for event in processor.finish()? {
309            yield event;
310        }
311    };
312
313    Box::pin(stream)
314}
315
316fn merge_streamed_response(response: &mut LLMResponse, streamed: LLMResponse) {
317    if response.content.as_deref().unwrap_or_default().is_empty() {
318        response.content = streamed.content;
319    } else if let (Some(content), Some(streamed_content)) =
320        (&mut response.content, streamed.content)
321        && !streamed_content.is_empty()
322        && !content.contains(&streamed_content)
323    {
324        content.push_str(&streamed_content);
325    }
326
327    if response.tool_calls.is_none() {
328        response.tool_calls = streamed.tool_calls;
329    }
330
331    if response.usage.is_none() {
332        response.usage = streamed.usage;
333    }
334
335    if response.reasoning.is_none() {
336        response.reasoning = streamed.reasoning;
337    }
338
339    if response.reasoning_details.is_none() {
340        response.reasoning_details = streamed.reasoning_details;
341    }
342
343    if response.tool_references.is_empty() && !streamed.tool_references.is_empty() {
344        response.tool_references = streamed.tool_references;
345    }
346
347    if response.request_id.is_none() {
348        response.request_id = streamed.request_id;
349    }
350
351    if response.organization_id.is_none() {
352        response.organization_id = streamed.organization_id;
353    }
354}
355
356fn extract_error_message(payload: &Value) -> Option<String> {
357    payload
358        .get("error")
359        .and_then(|error| error.get("message"))
360        .and_then(Value::as_str)
361        .map(ToOwned::to_owned)
362        .or_else(|| {
363            payload
364                .get("response")
365                .and_then(|response| response.get("error"))
366                .and_then(|error| error.get("message"))
367                .and_then(Value::as_str)
368                .map(ToOwned::to_owned)
369        })
370}
371
372fn provider_error(provider_name: &str, message: impl Into<String>) -> LLMError {
373    let message = error_display::format_llm_error(provider_name, &message.into());
374    LLMError::Provider {
375        message,
376        metadata: None,
377    }
378}
379
#[cfg(test)]
mod tests {
    use super::{
        ResponsesNormalizedStreamOptions, ResponsesNormalizedStreamProcessor, provider_error,
    };
    use crate::llm::provider::{FinishReason, LLMResponse, NormalizedStreamEvent, ToolCall};
    use serde_json::{Value, json};

    /// Options shared by every test: reasoning enabled, fixed provider and model.
    fn options() -> ResponsesNormalizedStreamOptions {
        ResponsesNormalizedStreamOptions {
            provider_name: "TestProvider",
            model: "gpt-5".to_string(),
            emit_reasoning: true,
        }
    }

    /// Minimal final-response parser: takes the text of the first content part
    /// of the first output item, if present.
    fn parse_response(value: Value) -> Result<LLMResponse, crate::llm::provider::LLMError> {
        let content = value
            .get("output")
            .and_then(Value::as_array)
            .and_then(|items| items.first())
            .and_then(|item| item.get("content"))
            .and_then(Value::as_array)
            .and_then(|content| content.first())
            .and_then(|item| item.get("text"))
            .and_then(Value::as_str)
            .map(ToOwned::to_owned);

        Ok(LLMResponse {
            content,
            model: "gpt-5".to_string(),
            finish_reason: FinishReason::Stop,
            ..Default::default()
        })
    }

    #[test]
    fn text_delta_and_completed_yield_text_then_done() {
        let mut processor = ResponsesNormalizedStreamProcessor::new(options(), parse_response);

        let events = processor
            .handle_payload(json!({
                "type": "response.output_text.delta",
                "delta": "hello"
            }))
            .expect("text delta should parse");
        assert!(matches!(
            events.as_slice(),
            [NormalizedStreamEvent::TextDelta { delta }] if delta == "hello"
        ));

        let completed_events = processor
            .handle_payload(json!({
                "type": "response.completed",
                "response": {
                    "output": [{
                        "type": "message",
                        "content": [{"type": "output_text", "text": "hello"}]
                    }]
                }
            }))
            .expect("completed event should parse");
        assert!(completed_events.is_empty());

        let finished = processor.finish().expect("finish should succeed");
        assert!(matches!(
            finished.as_slice(),
            [NormalizedStreamEvent::Done { response }]
                if response.content.as_deref() == Some("hello")
        ));
    }

    #[test]
    fn tool_call_deltas_emit_start_and_finish_with_assembled_tool_call() {
        let mut processor = ResponsesNormalizedStreamProcessor::new(options(), |_| {
            Ok(LLMResponse {
                model: "gpt-5".to_string(),
                ..Default::default()
            })
        });

        let started = processor
            .handle_payload(json!({
                "type": "response.output_item.added",
                "output_index": 0,
                "item": {
                    "type": "function_call",
                    "id": "call_1",
                    "name": "search_workspace"
                }
            }))
            .expect("output item metadata should parse");
        assert!(matches!(
            started.as_slice(),
            [NormalizedStreamEvent::ToolCallStart { call_id, name }]
                if call_id == "call_1" && name.as_deref() == Some("search_workspace")
        ));

        let first = processor
            .handle_payload(json!({
                "type": "response.function_call_arguments.delta",
                "item_id": "call_1",
                "delta": "{\"query\":\"vt"
            }))
            .expect("first tool delta should parse");
        assert!(matches!(
            first.as_slice(),
            [NormalizedStreamEvent::ToolCallDelta { call_id: delta_call_id, delta }]
            if delta_call_id == "call_1"
                && delta == "{\"query\":\"vt"
        ));

        let second = processor
            .handle_payload(json!({
                "type": "response.function_call_arguments.delta",
                "item_id": "call_1",
                "delta": "code\"}"
            }))
            .expect("second tool delta should parse");
        assert!(matches!(
            second.as_slice(),
            [NormalizedStreamEvent::ToolCallDelta { call_id, delta }]
                if call_id == "call_1" && delta == "code\"}"
        ));

        let finished = processor.finish().expect("finish should succeed");
        let response = match finished.as_slice() {
            [NormalizedStreamEvent::Done { response }] => response,
            _ => panic!("expected done event"),
        };
        let tool_calls = response
            .tool_calls
            .as_ref()
            .expect("tool call should be assembled");
        assert_eq!(
            tool_calls,
            &vec![ToolCall::function(
                "call_1".to_string(),
                "search_workspace".to_string(),
                "{\"query\":\"vtcode\"}".to_string(),
            )]
        );
    }

    #[test]
    fn refusal_delta_streams_visible_output() {
        let mut processor = ResponsesNormalizedStreamProcessor::new(options(), parse_response);

        let events = processor
            .handle_payload(json!({
                "type": "response.refusal.delta",
                "delta": "I can't help with that"
            }))
            .expect("refusal delta should parse");
        assert!(matches!(
            events.as_slice(),
            [NormalizedStreamEvent::TextDelta { delta }]
                if delta == "I can't help with that"
        ));

        let finished = processor.finish().expect("finish should succeed");
        assert!(matches!(
            finished.as_slice(),
            [NormalizedStreamEvent::Done { response }]
                if response.content.as_deref() == Some("I can't help with that")
        ));
    }

    #[test]
    fn failed_incomplete_and_error_events_surface_backend_message() {
        // Each payload is paired with the message it must surface, so a payload
        // cannot pass by matching a message that belongs to a different case.
        for (payload, expected) in [
            (
                json!({"type": "response.failed", "response": {"error": {"message": "failed"}}}),
                "failed",
            ),
            (
                json!({"type": "response.incomplete", "response": {"error": {"message": "incomplete"}}}),
                "incomplete",
            ),
            (
                json!({"type": "error", "error": {"message": "errored"}}),
                "errored",
            ),
        ] {
            let mut processor = ResponsesNormalizedStreamProcessor::new(options(), parse_response);
            let error = processor
                .handle_payload(payload)
                .expect_err("error payload should fail");
            let rendered = error.to_string();
            assert!(
                rendered.contains(expected),
                "expected `{expected}` in `{rendered}`"
            );
        }
    }

    #[test]
    fn unknown_documented_events_are_ignored() {
        let mut processor = ResponsesNormalizedStreamProcessor::new(options(), parse_response);
        let events = processor
            .handle_payload(json!({
                "type": "response.file_search_call.searching",
                "query": "needle"
            }))
            .expect("unknown documented event should be ignored");
        assert!(events.is_empty());
        processor
            .handle_payload(json!({
                "type": "response.code_interpreter_call.code.delta",
                "delta": "print(1)"
            }))
            .expect("code interpreter event should be ignored");

        let finished = processor.finish().expect("finish should succeed");
        assert!(matches!(
            finished.as_slice(),
            [NormalizedStreamEvent::Done { .. }]
        ));
    }

    #[test]
    fn missing_delta_reports_provider_error() {
        let mut processor = ResponsesNormalizedStreamProcessor::new(options(), parse_response);
        let error = processor
            .handle_payload(json!({"type": "response.output_text.delta"}))
            .expect_err("missing delta should fail");
        assert_eq!(
            error.to_string(),
            provider_error("TestProvider", "missing delta").to_string()
        );
    }
}