Skip to main content

vtcode_core/llm/providers/shared/
responses_stream.rs

1use crate::llm::error_display;
2use crate::llm::provider::{LLMError, LLMNormalizedStream, LLMResponse, NormalizedStreamEvent};
3use crate::llm::providers::shared::{extract_data_payload, find_sse_boundary};
4use async_stream::try_stream;
5use futures::StreamExt;
6use hashbrown::{HashMap, HashSet};
7use serde_json::{Value, json};
8
9use super::StreamAggregator;
10
/// Caller-supplied configuration for normalizing a Responses-style event stream.
#[derive(Debug, Clone)]
pub struct ResponsesNormalizedStreamOptions {
    /// Provider label used when formatting errors raised while streaming.
    pub provider_name: &'static str,
    /// Model identifier handed to the stream aggregator.
    pub model: String,
    /// When `false`, reasoning deltas are not surfaced as stream events.
    pub emit_reasoning: bool,
}
16
17struct ResponsesNormalizedStreamProcessor<P> {
18    options: ResponsesNormalizedStreamOptions,
19    parse_final_response: P,
20    aggregator: StreamAggregator,
21    seen_tool_calls: HashSet<String>,
22    tool_call_indexes: HashMap<String, usize>,
23    tool_call_names: HashMap<String, String>,
24    next_tool_call_index: usize,
25    final_response: Option<Value>,
26    done: bool,
27}
28
29impl<P> ResponsesNormalizedStreamProcessor<P>
30where
31    P: Fn(Value) -> Result<LLMResponse, LLMError>,
32{
33    fn new(options: ResponsesNormalizedStreamOptions, parse_final_response: P) -> Self {
34        Self {
35            aggregator: StreamAggregator::new(options.model.clone()),
36            options,
37            parse_final_response,
38            seen_tool_calls: HashSet::new(),
39            tool_call_indexes: HashMap::new(),
40            tool_call_names: HashMap::new(),
41            next_tool_call_index: 0,
42            final_response: None,
43            done: false,
44        }
45    }
46
47    fn is_done(&self) -> bool {
48        self.done
49    }
50
51    fn handle_payload(&mut self, payload: Value) -> Result<Vec<NormalizedStreamEvent>, LLMError> {
52        let mut events = Vec::new();
53
54        if let Some(usage) = payload.get("usage").cloned()
55            && let Ok(usage) = serde_json::from_value(usage)
56        {
57            self.aggregator.set_usage(usage);
58        }
59
60        let event_type = payload.get("type").and_then(Value::as_str).unwrap_or("");
61        match event_type {
62            "response.output_text.delta" => {
63                let delta = payload
64                    .get("delta")
65                    .and_then(Value::as_str)
66                    .ok_or_else(|| provider_error(self.options.provider_name, "missing delta"))?;
67                for event in self.aggregator.handle_content(delta) {
68                    match event {
69                        crate::llm::provider::LLMStreamEvent::Token { delta } => {
70                            events.push(NormalizedStreamEvent::TextDelta { delta });
71                        }
72                        crate::llm::provider::LLMStreamEvent::Reasoning { delta }
73                            if self.options.emit_reasoning =>
74                        {
75                            events.push(NormalizedStreamEvent::ReasoningDelta { delta });
76                        }
77                        _ => {}
78                    }
79                }
80            }
81            "response.refusal.delta" => {
82                let delta = payload
83                    .get("delta")
84                    .and_then(Value::as_str)
85                    .ok_or_else(|| provider_error(self.options.provider_name, "missing delta"))?;
86                if !delta.is_empty() {
87                    self.aggregator.content.push_str(delta);
88                    events.push(NormalizedStreamEvent::TextDelta {
89                        delta: delta.to_string(),
90                    });
91                }
92            }
93            "response.reasoning_text.delta"
94            | "response.reasoning_summary_text.delta"
95            | "response.reasoning_content.delta" => {
96                if self.options.emit_reasoning
97                    && let Some(delta) = payload.get("delta").and_then(Value::as_str)
98                    && let Some(delta) = self.aggregator.handle_reasoning(delta)
99                {
100                    events.push(NormalizedStreamEvent::ReasoningDelta { delta });
101                }
102            }
103            "response.output_item.added" | "response.output_item.done" => {
104                if let Some(item) = payload.get("item") {
105                    let tool_call = self.capture_tool_call_metadata(
106                        item,
107                        payload
108                            .get("output_index")
109                            .and_then(Value::as_u64)
110                            .map(|value| value as usize),
111                    );
112                    if let Some((call_id, name)) = tool_call {
113                        self.push_tool_call_start(&mut events, call_id, Some(name));
114                    }
115                }
116            }
117            "response.function_call_arguments.delta" => {
118                let delta = payload
119                    .get("delta")
120                    .and_then(Value::as_str)
121                    .ok_or_else(|| provider_error(self.options.provider_name, "missing delta"))?;
122                let call_id = payload
123                    .get("item_id")
124                    .and_then(Value::as_str)
125                    .or_else(|| payload.get("call_id").and_then(Value::as_str))
126                    .filter(|value| !value.is_empty())
127                    .map(ToOwned::to_owned)
128                    .unwrap_or_else(|| format!("tool_call_{}", self.next_tool_call_index));
129                let index = self.resolve_tool_call_index(
130                    &call_id,
131                    payload
132                        .get("output_index")
133                        .and_then(Value::as_u64)
134                        .map(|value| value as usize),
135                );
136
137                let name = self.tool_call_names.get(&call_id).cloned();
138                self.push_tool_call_start(&mut events, call_id.clone(), name);
139
140                if !delta.is_empty() {
141                    self.aggregator.handle_tool_calls(&[json!({
142                        "index": index,
143                        "id": call_id,
144                        "function": {
145                            "arguments": delta,
146                        }
147                    })]);
148                    events.push(NormalizedStreamEvent::ToolCallDelta {
149                        call_id,
150                        delta: delta.to_string(),
151                    });
152                }
153            }
154            "response.completed" => {
155                if let Some(response) = payload.get("response") {
156                    self.final_response = Some(response.clone());
157                }
158                self.done = true;
159            }
160            "response.failed" | "response.incomplete" | "error" => {
161                let message = extract_error_message(&payload)
162                    .unwrap_or_else(|| "unknown error from responses stream".to_string());
163                return Err(provider_error(self.options.provider_name, message));
164            }
165            _ => {}
166        }
167
168        Ok(events)
169    }
170
171    fn finish(self) -> Result<Vec<NormalizedStreamEvent>, LLMError> {
172        let streamed = self.aggregator.finalize();
173        let mut response = if let Some(final_response) = self.final_response {
174            match (self.parse_final_response)(final_response.clone()) {
175                Ok(response) => response,
176                Err(_)
177                    if final_response_output_is_empty(&final_response)
178                        && streamed_response_is_usable(&streamed) =>
179                {
180                    let mut response = streamed.clone();
181                    merge_final_response_metadata(&mut response, &final_response);
182                    response
183                }
184                Err(err) => return Err(err),
185            }
186        } else {
187            streamed.clone()
188        };
189
190        merge_streamed_response(&mut response, streamed);
191
192        let mut events = Vec::new();
193        if let Some(usage) = response.usage.clone() {
194            events.push(NormalizedStreamEvent::Usage { usage });
195        }
196        events.push(NormalizedStreamEvent::Done {
197            response: Box::new(response),
198        });
199        Ok(events)
200    }
201
202    fn capture_tool_call_metadata(
203        &mut self,
204        item: &Value,
205        output_index: Option<usize>,
206    ) -> Option<(String, String)> {
207        let item_type = item.get("type").and_then(Value::as_str).unwrap_or("");
208        if item_type != "function_call" {
209            return None;
210        }
211
212        let call_id = item
213            .get("id")
214            .and_then(Value::as_str)
215            .or_else(|| item.get("call_id").and_then(Value::as_str))
216            .filter(|value| !value.is_empty());
217        let name = item.get("name").and_then(Value::as_str).or_else(|| {
218            item.get("function")
219                .and_then(|function| function.get("name"))
220                .and_then(Value::as_str)
221        });
222        if let (Some(call_id), Some(name)) = (call_id, name) {
223            self.tool_call_names
224                .entry(call_id.to_string())
225                .or_insert_with(|| name.to_string());
226            let index = self.resolve_tool_call_index(call_id, output_index);
227            self.aggregator.handle_tool_calls(&[json!({
228                "index": index,
229                "id": call_id,
230                "function": {
231                    "name": name,
232                }
233            })]);
234            return Some((call_id.to_string(), name.to_string()));
235        }
236
237        None
238    }
239
240    fn push_tool_call_start(
241        &mut self,
242        events: &mut Vec<NormalizedStreamEvent>,
243        call_id: String,
244        name: Option<String>,
245    ) {
246        if self.seen_tool_calls.insert(call_id.clone()) {
247            events.push(NormalizedStreamEvent::ToolCallStart { call_id, name });
248        }
249    }
250
251    fn resolve_tool_call_index(&mut self, call_id: &str, output_index: Option<usize>) -> usize {
252        if let Some(index) = output_index {
253            self.tool_call_indexes.insert(call_id.to_string(), index);
254            self.next_tool_call_index = self.next_tool_call_index.max(index + 1);
255            return index;
256        }
257
258        if let Some(index) = self.tool_call_indexes.get(call_id).copied() {
259            return index;
260        }
261
262        let index = self.next_tool_call_index;
263        self.tool_call_indexes.insert(call_id.to_string(), index);
264        self.next_tool_call_index += 1;
265        index
266    }
267}
268
269pub fn create_responses_normalized_stream<P>(
270    response: reqwest::Response,
271    options: ResponsesNormalizedStreamOptions,
272    parse_final_response: P,
273) -> LLMNormalizedStream
274where
275    P: Fn(Value) -> Result<LLMResponse, LLMError> + Send + 'static,
276{
277    let stream = try_stream! {
278        let provider_name = options.provider_name;
279        let mut processor = ResponsesNormalizedStreamProcessor::new(options, parse_final_response);
280        let mut body_stream = response.bytes_stream();
281        let mut buffer = String::new();
282
283        while let Some(chunk_result) = body_stream.next().await {
284            let chunk = chunk_result.map_err(|err| provider_error(
285                provider_name,
286                format!("streaming error: {err}"),
287            ))?;
288            buffer.push_str(&String::from_utf8_lossy(&chunk));
289
290            while let Some((split_idx, delimiter_len)) = find_sse_boundary(&buffer) {
291                let event = buffer[..split_idx].to_string();
292                buffer.drain(..split_idx + delimiter_len);
293
294                if let Some(data_payload) = extract_data_payload(&event) {
295                    let trimmed_payload = data_payload.trim();
296                    if trimmed_payload.is_empty() || trimmed_payload == "[DONE]" {
297                        continue;
298                    }
299
300                    let payload: Value = serde_json::from_str(trimmed_payload).map_err(|err| {
301                        provider_error(provider_name, format!("invalid stream payload: {err}"))
302                    })?;
303
304                    for event in processor.handle_payload(payload)? {
305                        yield event;
306                    }
307
308                    if processor.is_done() {
309                        break;
310                    }
311                }
312            }
313
314            if processor.is_done() {
315                break;
316            }
317        }
318
319        for event in processor.finish()? {
320            yield event;
321        }
322    };
323
324    Box::pin(stream)
325}
326
327fn streamed_response_is_usable(response: &LLMResponse) -> bool {
328    response
329        .content
330        .as_deref()
331        .is_some_and(|content| !content.is_empty())
332        || response
333            .tool_calls
334            .as_ref()
335            .is_some_and(|tool_calls| !tool_calls.is_empty())
336        || response
337            .reasoning
338            .as_deref()
339            .is_some_and(|reasoning| !reasoning.is_empty())
340        || response
341            .reasoning_details
342            .as_ref()
343            .is_some_and(|details| !details.is_empty())
344}
345
346fn final_response_output_is_empty(final_response: &Value) -> bool {
347    final_response
348        .get("output")
349        .and_then(Value::as_array)
350        .is_some_and(Vec::is_empty)
351}
352
353fn merge_streamed_response(response: &mut LLMResponse, streamed: LLMResponse) {
354    if response.content.as_deref().unwrap_or_default().is_empty() {
355        response.content = streamed.content;
356    } else if let (Some(content), Some(streamed_content)) =
357        (&mut response.content, streamed.content)
358        && !streamed_content.is_empty()
359        && !content.contains(&streamed_content)
360    {
361        content.push_str(&streamed_content);
362    }
363
364    if response.tool_calls.is_none() {
365        response.tool_calls = streamed.tool_calls;
366    }
367
368    if response.usage.is_none() {
369        response.usage = streamed.usage;
370    }
371
372    if response.reasoning.is_none() {
373        response.reasoning = streamed.reasoning;
374    }
375
376    if response.reasoning_details.is_none() {
377        response.reasoning_details = streamed.reasoning_details;
378    }
379
380    if response.tool_references.is_empty() && !streamed.tool_references.is_empty() {
381        response.tool_references = streamed.tool_references;
382    }
383
384    if response.request_id.is_none() {
385        response.request_id = streamed.request_id;
386    }
387
388    if response.organization_id.is_none() {
389        response.organization_id = streamed.organization_id;
390    }
391}
392
393fn merge_final_response_metadata(response: &mut LLMResponse, final_response: &Value) {
394    if let Some(usage) = parse_responses_usage(final_response) {
395        response.usage = Some(usage);
396    }
397
398    if let Some(request_id) = final_response
399        .get("id")
400        .and_then(Value::as_str)
401        .or_else(|| final_response.get("request_id").and_then(Value::as_str))
402    {
403        response.request_id = Some(request_id.to_string());
404    }
405}
406
407fn parse_responses_usage(final_response: &Value) -> Option<crate::llm::provider::Usage> {
408    let usage_value = final_response.get("usage")?;
409    Some(crate::llm::provider::Usage {
410        prompt_tokens: usage_value
411            .get("input_tokens")
412            .or_else(|| usage_value.get("prompt_tokens"))
413            .and_then(Value::as_u64)
414            .and_then(|value| u32::try_from(value).ok())
415            .unwrap_or(0),
416        completion_tokens: usage_value
417            .get("output_tokens")
418            .or_else(|| usage_value.get("completion_tokens"))
419            .and_then(Value::as_u64)
420            .and_then(|value| u32::try_from(value).ok())
421            .unwrap_or(0),
422        total_tokens: usage_value
423            .get("total_tokens")
424            .and_then(Value::as_u64)
425            .and_then(|value| u32::try_from(value).ok())
426            .unwrap_or(0),
427        cached_prompt_tokens: None,
428        cache_creation_tokens: None,
429        cache_read_tokens: None,
430        iterations: None,
431    })
432}
433
434fn extract_error_message(payload: &Value) -> Option<String> {
435    payload
436        .get("error")
437        .and_then(|error| error.get("message"))
438        .and_then(Value::as_str)
439        .map(ToOwned::to_owned)
440        .or_else(|| {
441            payload
442                .get("response")
443                .and_then(|response| response.get("error"))
444                .and_then(|error| error.get("message"))
445                .and_then(Value::as_str)
446                .map(ToOwned::to_owned)
447        })
448}
449
450fn provider_error(provider_name: &str, message: impl Into<String>) -> LLMError {
451    let message = error_display::format_llm_error(provider_name, &message.into());
452    LLMError::Provider {
453        message,
454        metadata: None,
455    }
456}
457
#[cfg(test)]
mod tests {
    use super::{
        ResponsesNormalizedStreamOptions, ResponsesNormalizedStreamProcessor, provider_error,
    };
    use crate::llm::provider::{FinishReason, LLMResponse, NormalizedStreamEvent, ToolCall};
    use serde_json::{Value, json};

    /// Options shared by every test: reasoning enabled, fixed provider and model.
    fn options() -> ResponsesNormalizedStreamOptions {
        ResponsesNormalizedStreamOptions {
            provider_name: "TestProvider",
            model: "gpt-5".to_string(),
            emit_reasoning: true,
        }
    }

    /// Minimal final-response parser: takes the text of the first content part
    /// of the first output item, if there is one.
    fn parse_response(value: Value) -> Result<LLMResponse, crate::llm::provider::LLMError> {
        let content = value
            .get("output")
            .and_then(Value::as_array)
            .and_then(|items| items.first())
            .and_then(|item| item.get("content"))
            .and_then(Value::as_array)
            .and_then(|content| content.first())
            .and_then(|item| item.get("text"))
            .and_then(Value::as_str)
            .map(ToOwned::to_owned);

        Ok(LLMResponse {
            content,
            model: "gpt-5".to_string(),
            finish_reason: FinishReason::Stop,
            ..Default::default()
        })
    }

    #[test]
    fn text_delta_and_completed_yield_text_then_done() {
        let mut processor = ResponsesNormalizedStreamProcessor::new(options(), parse_response);

        let events = processor
            .handle_payload(json!({
                "type": "response.output_text.delta",
                "delta": "hello"
            }))
            .expect("text delta should parse");
        assert!(matches!(
            events.as_slice(),
            [NormalizedStreamEvent::TextDelta { delta }] if delta == "hello"
        ));

        let completed_events = processor
            .handle_payload(json!({
                "type": "response.completed",
                "response": {
                    "output": [{
                        "type": "message",
                        "content": [{"type": "output_text", "text": "hello"}]
                    }]
                }
            }))
            .expect("completed event should parse");
        assert!(completed_events.is_empty());

        let finished = processor.finish().expect("finish should succeed");
        assert!(matches!(
            finished.as_slice(),
            [NormalizedStreamEvent::Done { response }]
                if response.content.as_deref() == Some("hello")
        ));
    }

    #[test]
    fn empty_final_response_uses_streamed_text_and_preserves_metadata() {
        let mut processor = ResponsesNormalizedStreamProcessor::new(options(), |value| {
            let output = value
                .get("output")
                .and_then(Value::as_array)
                .ok_or_else(|| provider_error("TestProvider", "missing output"))?;
            if output.is_empty() {
                return Err(provider_error("TestProvider", "No output in response"));
            }
            parse_response(value)
        });

        processor
            .handle_payload(json!({
                "type": "response.output_text.delta",
                "delta": "streamed answer"
            }))
            .expect("text delta should parse");
        processor
            .handle_payload(json!({
                "type": "response.completed",
                "response": {
                    "id": "resp_streamed",
                    "output": [],
                    "usage": {
                        "input_tokens": 11,
                        "output_tokens": 7,
                        "total_tokens": 18
                    }
                }
            }))
            .expect("completed event should parse");

        let finished = processor.finish().expect("finish should succeed");
        let [
            NormalizedStreamEvent::Usage { usage },
            NormalizedStreamEvent::Done { response },
        ] = finished.as_slice()
        else {
            panic!("expected usage then done");
        };
        assert_eq!(usage.prompt_tokens, 11);
        assert_eq!(usage.completion_tokens, 7);
        assert_eq!(usage.total_tokens, 18);
        assert_eq!(response.content.as_deref(), Some("streamed answer"));
        assert_eq!(response.request_id.as_deref(), Some("resp_streamed"));
    }

    #[test]
    fn tool_call_deltas_emit_start_and_finish_with_assembled_tool_call() {
        let mut processor = ResponsesNormalizedStreamProcessor::new(options(), |_| {
            Ok(LLMResponse {
                model: "gpt-5".to_string(),
                ..Default::default()
            })
        });

        let started = processor
            .handle_payload(json!({
                "type": "response.output_item.added",
                "output_index": 0,
                "item": {
                    "type": "function_call",
                    "id": "call_1",
                    "name": "search_workspace"
                }
            }))
            .expect("output item metadata should parse");
        assert!(matches!(
            started.as_slice(),
            [NormalizedStreamEvent::ToolCallStart { call_id, name }]
                if call_id == "call_1" && name.as_deref() == Some("search_workspace")
        ));

        let first = processor
            .handle_payload(json!({
                "type": "response.function_call_arguments.delta",
                "item_id": "call_1",
                "delta": "{\"query\":\"vt"
            }))
            .expect("first tool delta should parse");
        assert!(matches!(
            first.as_slice(),
            [NormalizedStreamEvent::ToolCallDelta { call_id: delta_call_id, delta }]
            if delta_call_id == "call_1"
                && delta == "{\"query\":\"vt"
        ));

        let second = processor
            .handle_payload(json!({
                "type": "response.function_call_arguments.delta",
                "item_id": "call_1",
                "delta": "code\"}"
            }))
            .expect("second tool delta should parse");
        assert!(matches!(
            second.as_slice(),
            [NormalizedStreamEvent::ToolCallDelta { call_id, delta }]
                if call_id == "call_1" && delta == "code\"}"
        ));

        let finished = processor.finish().expect("finish should succeed");
        let response = match finished.as_slice() {
            [NormalizedStreamEvent::Done { response }] => response,
            _ => panic!("expected done event"),
        };
        let tool_calls = response
            .tool_calls
            .as_ref()
            .expect("tool call should be assembled");
        assert_eq!(
            tool_calls,
            &vec![ToolCall::function(
                "call_1".to_string(),
                "search_workspace".to_string(),
                "{\"query\":\"vtcode\"}".to_string(),
            )]
        );
    }

    #[test]
    fn refusal_delta_streams_visible_output() {
        let mut processor = ResponsesNormalizedStreamProcessor::new(options(), parse_response);

        let events = processor
            .handle_payload(json!({
                "type": "response.refusal.delta",
                "delta": "I can't help with that"
            }))
            .expect("refusal delta should parse");
        assert!(matches!(
            events.as_slice(),
            [NormalizedStreamEvent::TextDelta { delta }]
                if delta == "I can't help with that"
        ));

        let finished = processor.finish().expect("finish should succeed");
        assert!(matches!(
            finished.as_slice(),
            [NormalizedStreamEvent::Done { response }]
                if response.content.as_deref() == Some("I can't help with that")
        ));
    }

    #[test]
    fn failed_incomplete_and_error_events_surface_backend_message() {
        // Each payload is paired with the message it must surface, so a payload
        // cannot pass by matching some other case's text.
        let cases = [
            (
                json!({"type": "response.failed", "response": {"error": {"message": "failed"}}}),
                "failed",
            ),
            (
                json!({"type": "response.incomplete", "response": {"error": {"message": "incomplete"}}}),
                "incomplete",
            ),
            (
                json!({"type": "error", "error": {"message": "errored"}}),
                "errored",
            ),
        ];

        for (payload, expected) in cases {
            let mut processor = ResponsesNormalizedStreamProcessor::new(options(), parse_response);
            let error = processor
                .handle_payload(payload)
                .expect_err("error payload should fail");
            let rendered = error.to_string();
            assert!(
                rendered.contains(expected),
                "expected `{expected}` in `{rendered}`"
            );
        }
    }

    #[test]
    fn unknown_documented_events_are_ignored() {
        let mut processor = ResponsesNormalizedStreamProcessor::new(options(), parse_response);
        let events = processor
            .handle_payload(json!({
                "type": "response.file_search_call.searching",
                "query": "needle"
            }))
            .expect("unknown documented event should be ignored");
        assert!(events.is_empty());
        processor
            .handle_payload(json!({
                "type": "response.code_interpreter_call.code.delta",
                "delta": "print(1)"
            }))
            .expect("code interpreter event should be ignored");

        let finished = processor.finish().expect("finish should succeed");
        assert!(matches!(
            finished.as_slice(),
            [NormalizedStreamEvent::Done { .. }]
        ));
    }

    #[test]
    fn missing_delta_reports_provider_error() {
        let mut processor = ResponsesNormalizedStreamProcessor::new(options(), parse_response);
        let error = processor
            .handle_payload(json!({"type": "response.output_text.delta"}))
            .expect_err("missing delta should fail");
        assert_eq!(
            error.to_string(),
            provider_error("TestProvider", "missing delta").to_string()
        );
    }
}