vtcode_core/llm/providers/shared/
responses_stream.rs1use crate::llm::error_display;
2use crate::llm::provider::{LLMError, LLMNormalizedStream, LLMResponse, NormalizedStreamEvent};
3use crate::llm::providers::shared::{extract_data_payload, find_sse_boundary};
4use async_stream::try_stream;
5use futures::StreamExt;
6use hashbrown::{HashMap, HashSet};
7use serde_json::{Value, json};
8
9use super::StreamAggregator;
10
/// Configuration for a normalized Responses-API event stream.
///
/// Passed to [`create_responses_normalized_stream`] and kept by the stream
/// processor for the lifetime of one streamed response.
#[derive(Debug, Clone)]
pub struct ResponsesNormalizedStreamOptions {
    /// Provider label used when formatting error messages.
    pub provider_name: &'static str,
    /// Model identifier handed to the stream aggregator.
    pub model: String,
    /// When `false`, reasoning deltas are not forwarded as stream events.
    pub emit_reasoning: bool,
}
16
/// Stateful translator from raw Responses-API stream payloads to
/// `NormalizedStreamEvent`s.
///
/// `P` is the callback that turns the final `response` JSON object into an
/// `LLMResponse`.
struct ResponsesNormalizedStreamProcessor<P> {
    /// Provider name, model and reasoning switch for this stream.
    options: ResponsesNormalizedStreamOptions,
    /// Parser applied to the payload captured from `response.completed`.
    parse_final_response: P,
    /// Accumulates streamed content, reasoning, tool calls and usage.
    aggregator: StreamAggregator,
    /// Call ids for which a `ToolCallStart` event has already been emitted.
    seen_tool_calls: HashSet<String>,
    /// Index assigned to each tool call id when forwarding to the aggregator.
    tool_call_indexes: HashMap<String, usize>,
    /// Function name recorded for each tool call id.
    tool_call_names: HashMap<String, String>,
    /// Next index handed out to a tool call that arrives without `output_index`.
    next_tool_call_index: usize,
    /// `response` object captured from the `response.completed` event, if any.
    final_response: Option<Value>,
    /// Set once `response.completed` has been observed.
    done: bool,
}
28
29impl<P> ResponsesNormalizedStreamProcessor<P>
30where
31 P: Fn(Value) -> Result<LLMResponse, LLMError>,
32{
33 fn new(options: ResponsesNormalizedStreamOptions, parse_final_response: P) -> Self {
34 Self {
35 aggregator: StreamAggregator::new(options.model.clone()),
36 options,
37 parse_final_response,
38 seen_tool_calls: HashSet::new(),
39 tool_call_indexes: HashMap::new(),
40 tool_call_names: HashMap::new(),
41 next_tool_call_index: 0,
42 final_response: None,
43 done: false,
44 }
45 }
46
47 fn is_done(&self) -> bool {
48 self.done
49 }
50
51 fn handle_payload(&mut self, payload: Value) -> Result<Vec<NormalizedStreamEvent>, LLMError> {
52 let mut events = Vec::new();
53
54 if let Some(usage) = payload.get("usage").cloned()
55 && let Ok(usage) = serde_json::from_value(usage)
56 {
57 self.aggregator.set_usage(usage);
58 }
59
60 let event_type = payload.get("type").and_then(Value::as_str).unwrap_or("");
61 match event_type {
62 "response.output_text.delta" => {
63 let delta = payload
64 .get("delta")
65 .and_then(Value::as_str)
66 .ok_or_else(|| provider_error(self.options.provider_name, "missing delta"))?;
67 for event in self.aggregator.handle_content(delta) {
68 match event {
69 crate::llm::provider::LLMStreamEvent::Token { delta } => {
70 events.push(NormalizedStreamEvent::TextDelta { delta });
71 }
72 crate::llm::provider::LLMStreamEvent::Reasoning { delta }
73 if self.options.emit_reasoning =>
74 {
75 events.push(NormalizedStreamEvent::ReasoningDelta { delta });
76 }
77 _ => {}
78 }
79 }
80 }
81 "response.refusal.delta" => {
82 let delta = payload
83 .get("delta")
84 .and_then(Value::as_str)
85 .ok_or_else(|| provider_error(self.options.provider_name, "missing delta"))?;
86 if !delta.is_empty() {
87 self.aggregator.content.push_str(delta);
88 events.push(NormalizedStreamEvent::TextDelta {
89 delta: delta.to_string(),
90 });
91 }
92 }
93 "response.reasoning_text.delta"
94 | "response.reasoning_summary_text.delta"
95 | "response.reasoning_content.delta" => {
96 if self.options.emit_reasoning
97 && let Some(delta) = payload.get("delta").and_then(Value::as_str)
98 && let Some(delta) = self.aggregator.handle_reasoning(delta)
99 {
100 events.push(NormalizedStreamEvent::ReasoningDelta { delta });
101 }
102 }
103 "response.output_item.added" | "response.output_item.done" => {
104 if let Some(item) = payload.get("item") {
105 let tool_call = self.capture_tool_call_metadata(
106 item,
107 payload
108 .get("output_index")
109 .and_then(Value::as_u64)
110 .map(|value| value as usize),
111 );
112 if let Some((call_id, name)) = tool_call {
113 self.push_tool_call_start(&mut events, call_id, Some(name));
114 }
115 }
116 }
117 "response.function_call_arguments.delta" => {
118 let delta = payload
119 .get("delta")
120 .and_then(Value::as_str)
121 .ok_or_else(|| provider_error(self.options.provider_name, "missing delta"))?;
122 let call_id = payload
123 .get("item_id")
124 .and_then(Value::as_str)
125 .or_else(|| payload.get("call_id").and_then(Value::as_str))
126 .filter(|value| !value.is_empty())
127 .map(ToOwned::to_owned)
128 .unwrap_or_else(|| format!("tool_call_{}", self.next_tool_call_index));
129 let index = self.resolve_tool_call_index(
130 &call_id,
131 payload
132 .get("output_index")
133 .and_then(Value::as_u64)
134 .map(|value| value as usize),
135 );
136
137 let name = self.tool_call_names.get(&call_id).cloned();
138 self.push_tool_call_start(&mut events, call_id.clone(), name);
139
140 if !delta.is_empty() {
141 self.aggregator.handle_tool_calls(&[json!({
142 "index": index,
143 "id": call_id,
144 "function": {
145 "arguments": delta,
146 }
147 })]);
148 events.push(NormalizedStreamEvent::ToolCallDelta {
149 call_id,
150 delta: delta.to_string(),
151 });
152 }
153 }
154 "response.completed" => {
155 if let Some(response) = payload.get("response") {
156 self.final_response = Some(response.clone());
157 }
158 self.done = true;
159 }
160 "response.failed" | "response.incomplete" | "error" => {
161 let message = extract_error_message(&payload)
162 .unwrap_or_else(|| "unknown error from responses stream".to_string());
163 return Err(provider_error(self.options.provider_name, message));
164 }
165 _ => {}
166 }
167
168 Ok(events)
169 }
170
171 fn finish(self) -> Result<Vec<NormalizedStreamEvent>, LLMError> {
172 let streamed = self.aggregator.finalize();
173 let mut response = if let Some(final_response) = self.final_response {
174 (self.parse_final_response)(final_response)?
175 } else {
176 streamed.clone()
177 };
178
179 merge_streamed_response(&mut response, streamed);
180
181 let mut events = Vec::new();
182 if let Some(usage) = response.usage.clone() {
183 events.push(NormalizedStreamEvent::Usage { usage });
184 }
185 events.push(NormalizedStreamEvent::Done {
186 response: Box::new(response),
187 });
188 Ok(events)
189 }
190
191 fn capture_tool_call_metadata(
192 &mut self,
193 item: &Value,
194 output_index: Option<usize>,
195 ) -> Option<(String, String)> {
196 let item_type = item.get("type").and_then(Value::as_str).unwrap_or("");
197 if item_type != "function_call" {
198 return None;
199 }
200
201 let call_id = item
202 .get("id")
203 .and_then(Value::as_str)
204 .or_else(|| item.get("call_id").and_then(Value::as_str))
205 .filter(|value| !value.is_empty());
206 let name = item.get("name").and_then(Value::as_str).or_else(|| {
207 item.get("function")
208 .and_then(|function| function.get("name"))
209 .and_then(Value::as_str)
210 });
211 if let (Some(call_id), Some(name)) = (call_id, name) {
212 self.tool_call_names
213 .entry(call_id.to_string())
214 .or_insert_with(|| name.to_string());
215 let index = self.resolve_tool_call_index(call_id, output_index);
216 self.aggregator.handle_tool_calls(&[json!({
217 "index": index,
218 "id": call_id,
219 "function": {
220 "name": name,
221 }
222 })]);
223 return Some((call_id.to_string(), name.to_string()));
224 }
225
226 None
227 }
228
229 fn push_tool_call_start(
230 &mut self,
231 events: &mut Vec<NormalizedStreamEvent>,
232 call_id: String,
233 name: Option<String>,
234 ) {
235 if self.seen_tool_calls.insert(call_id.clone()) {
236 events.push(NormalizedStreamEvent::ToolCallStart { call_id, name });
237 }
238 }
239
240 fn resolve_tool_call_index(&mut self, call_id: &str, output_index: Option<usize>) -> usize {
241 if let Some(index) = output_index {
242 self.tool_call_indexes.insert(call_id.to_string(), index);
243 self.next_tool_call_index = self.next_tool_call_index.max(index + 1);
244 return index;
245 }
246
247 if let Some(index) = self.tool_call_indexes.get(call_id).copied() {
248 return index;
249 }
250
251 let index = self.next_tool_call_index;
252 self.tool_call_indexes.insert(call_id.to_string(), index);
253 self.next_tool_call_index += 1;
254 index
255 }
256}
257
258pub fn create_responses_normalized_stream<P>(
259 response: reqwest::Response,
260 options: ResponsesNormalizedStreamOptions,
261 parse_final_response: P,
262) -> LLMNormalizedStream
263where
264 P: Fn(Value) -> Result<LLMResponse, LLMError> + Send + 'static,
265{
266 let stream = try_stream! {
267 let provider_name = options.provider_name;
268 let mut processor = ResponsesNormalizedStreamProcessor::new(options, parse_final_response);
269 let mut body_stream = response.bytes_stream();
270 let mut buffer = String::new();
271
272 while let Some(chunk_result) = body_stream.next().await {
273 let chunk = chunk_result.map_err(|err| provider_error(
274 provider_name,
275 format!("streaming error: {err}"),
276 ))?;
277 buffer.push_str(&String::from_utf8_lossy(&chunk));
278
279 while let Some((split_idx, delimiter_len)) = find_sse_boundary(&buffer) {
280 let event = buffer[..split_idx].to_string();
281 buffer.drain(..split_idx + delimiter_len);
282
283 if let Some(data_payload) = extract_data_payload(&event) {
284 let trimmed_payload = data_payload.trim();
285 if trimmed_payload.is_empty() || trimmed_payload == "[DONE]" {
286 continue;
287 }
288
289 let payload: Value = serde_json::from_str(trimmed_payload).map_err(|err| {
290 provider_error(provider_name, format!("invalid stream payload: {err}"))
291 })?;
292
293 for event in processor.handle_payload(payload)? {
294 yield event;
295 }
296
297 if processor.is_done() {
298 break;
299 }
300 }
301 }
302
303 if processor.is_done() {
304 break;
305 }
306 }
307
308 for event in processor.finish()? {
309 yield event;
310 }
311 };
312
313 Box::pin(stream)
314}
315
316fn merge_streamed_response(response: &mut LLMResponse, streamed: LLMResponse) {
317 if response.content.as_deref().unwrap_or_default().is_empty() {
318 response.content = streamed.content;
319 } else if let (Some(content), Some(streamed_content)) =
320 (&mut response.content, streamed.content)
321 && !streamed_content.is_empty()
322 && !content.contains(&streamed_content)
323 {
324 content.push_str(&streamed_content);
325 }
326
327 if response.tool_calls.is_none() {
328 response.tool_calls = streamed.tool_calls;
329 }
330
331 if response.usage.is_none() {
332 response.usage = streamed.usage;
333 }
334
335 if response.reasoning.is_none() {
336 response.reasoning = streamed.reasoning;
337 }
338
339 if response.reasoning_details.is_none() {
340 response.reasoning_details = streamed.reasoning_details;
341 }
342
343 if response.tool_references.is_empty() && !streamed.tool_references.is_empty() {
344 response.tool_references = streamed.tool_references;
345 }
346
347 if response.request_id.is_none() {
348 response.request_id = streamed.request_id;
349 }
350
351 if response.organization_id.is_none() {
352 response.organization_id = streamed.organization_id;
353 }
354}
355
356fn extract_error_message(payload: &Value) -> Option<String> {
357 payload
358 .get("error")
359 .and_then(|error| error.get("message"))
360 .and_then(Value::as_str)
361 .map(ToOwned::to_owned)
362 .or_else(|| {
363 payload
364 .get("response")
365 .and_then(|response| response.get("error"))
366 .and_then(|error| error.get("message"))
367 .and_then(Value::as_str)
368 .map(ToOwned::to_owned)
369 })
370}
371
372fn provider_error(provider_name: &str, message: impl Into<String>) -> LLMError {
373 let message = error_display::format_llm_error(provider_name, &message.into());
374 LLMError::Provider {
375 message,
376 metadata: None,
377 }
378}
379
#[cfg(test)]
mod tests {
    use super::{
        ResponsesNormalizedStreamOptions, ResponsesNormalizedStreamProcessor, provider_error,
    };
    use crate::llm::provider::{FinishReason, LLMResponse, NormalizedStreamEvent, ToolCall};
    use serde_json::{Value, json};

    /// Options shared by every test: reasoning enabled, fixed provider and model.
    fn options() -> ResponsesNormalizedStreamOptions {
        ResponsesNormalizedStreamOptions {
            provider_name: "TestProvider",
            model: "gpt-5".to_string(),
            emit_reasoning: true,
        }
    }

    /// Minimal final-response parser: takes the text of the first content part
    /// of the first output item.
    fn parse_response(value: Value) -> Result<LLMResponse, crate::llm::provider::LLMError> {
        let content = value
            .get("output")
            .and_then(Value::as_array)
            .and_then(|items| items.first())
            .and_then(|item| item.get("content"))
            .and_then(Value::as_array)
            .and_then(|content| content.first())
            .and_then(|item| item.get("text"))
            .and_then(Value::as_str)
            .map(ToOwned::to_owned);

        Ok(LLMResponse {
            content,
            model: "gpt-5".to_string(),
            finish_reason: FinishReason::Stop,
            ..Default::default()
        })
    }

    #[test]
    fn text_delta_and_completed_yield_text_then_done() {
        let mut processor = ResponsesNormalizedStreamProcessor::new(options(), parse_response);

        let events = processor
            .handle_payload(json!({
                "type": "response.output_text.delta",
                "delta": "hello"
            }))
            .expect("text delta should parse");
        assert!(matches!(
            events.as_slice(),
            [NormalizedStreamEvent::TextDelta { delta }] if delta == "hello"
        ));

        let completed_events = processor
            .handle_payload(json!({
                "type": "response.completed",
                "response": {
                    "output": [{
                        "type": "message",
                        "content": [{"type": "output_text", "text": "hello"}]
                    }]
                }
            }))
            .expect("completed event should parse");
        assert!(completed_events.is_empty());

        let finished = processor.finish().expect("finish should succeed");
        assert!(matches!(
            finished.as_slice(),
            [NormalizedStreamEvent::Done { response }]
                if response.content.as_deref() == Some("hello")
        ));
    }

    #[test]
    fn tool_call_deltas_emit_start_and_finish_with_assembled_tool_call() {
        let mut processor = ResponsesNormalizedStreamProcessor::new(options(), |_| {
            Ok(LLMResponse {
                model: "gpt-5".to_string(),
                ..Default::default()
            })
        });

        let started = processor
            .handle_payload(json!({
                "type": "response.output_item.added",
                "output_index": 0,
                "item": {
                    "type": "function_call",
                    "id": "call_1",
                    "name": "search_workspace"
                }
            }))
            .expect("output item metadata should parse");
        assert!(matches!(
            started.as_slice(),
            [NormalizedStreamEvent::ToolCallStart { call_id, name }]
                if call_id == "call_1" && name.as_deref() == Some("search_workspace")
        ));

        let first = processor
            .handle_payload(json!({
                "type": "response.function_call_arguments.delta",
                "item_id": "call_1",
                "delta": "{\"query\":\"vt"
            }))
            .expect("first tool delta should parse");
        assert!(matches!(
            first.as_slice(),
            [NormalizedStreamEvent::ToolCallDelta { call_id: delta_call_id, delta }]
                if delta_call_id == "call_1"
                    && delta == "{\"query\":\"vt"
        ));

        let second = processor
            .handle_payload(json!({
                "type": "response.function_call_arguments.delta",
                "item_id": "call_1",
                "delta": "code\"}"
            }))
            .expect("second tool delta should parse");
        assert!(matches!(
            second.as_slice(),
            [NormalizedStreamEvent::ToolCallDelta { call_id, delta }]
                if call_id == "call_1" && delta == "code\"}"
        ));

        let finished = processor.finish().expect("finish should succeed");
        let response = match finished.as_slice() {
            [NormalizedStreamEvent::Done { response }] => response,
            _ => panic!("expected done event"),
        };
        let tool_calls = response
            .tool_calls
            .as_ref()
            .expect("tool call should be assembled");
        assert_eq!(
            tool_calls,
            &vec![ToolCall::function(
                "call_1".to_string(),
                "search_workspace".to_string(),
                "{\"query\":\"vtcode\"}".to_string(),
            )]
        );
    }

    #[test]
    fn refusal_delta_streams_visible_output() {
        let mut processor = ResponsesNormalizedStreamProcessor::new(options(), parse_response);

        let events = processor
            .handle_payload(json!({
                "type": "response.refusal.delta",
                "delta": "I can't help with that"
            }))
            .expect("refusal delta should parse");
        assert!(matches!(
            events.as_slice(),
            [NormalizedStreamEvent::TextDelta { delta }]
                if delta == "I can't help with that"
        ));

        let finished = processor.finish().expect("finish should succeed");
        assert!(matches!(
            finished.as_slice(),
            [NormalizedStreamEvent::Done { response }]
                if response.content.as_deref() == Some("I can't help with that")
        ));
    }

    #[test]
    fn failed_incomplete_and_error_events_surface_backend_message() {
        // Each payload is paired with the message it must surface, so a mix-up
        // between the two lookup paths cannot pass unnoticed.
        let cases = [
            (
                json!({"type": "response.failed", "response": {"error": {"message": "failed"}}}),
                "failed",
            ),
            (
                json!({"type": "response.incomplete", "response": {"error": {"message": "incomplete"}}}),
                "incomplete",
            ),
            (
                json!({"type": "error", "error": {"message": "errored"}}),
                "errored",
            ),
        ];

        for (payload, expected_message) in cases {
            let mut processor = ResponsesNormalizedStreamProcessor::new(options(), parse_response);
            let error = processor
                .handle_payload(payload)
                .expect_err("error payload should fail");
            let rendered = error.to_string();
            assert!(
                rendered.contains(expected_message),
                "expected `{expected_message}` in `{rendered}`"
            );
        }
    }

    #[test]
    fn unknown_documented_events_are_ignored() {
        let mut processor = ResponsesNormalizedStreamProcessor::new(options(), parse_response);
        let events = processor
            .handle_payload(json!({
                "type": "response.file_search_call.searching",
                "query": "needle"
            }))
            .expect("unknown documented event should be ignored");
        assert!(events.is_empty());
        processor
            .handle_payload(json!({
                "type": "response.code_interpreter_call.code.delta",
                "delta": "print(1)"
            }))
            .expect("code interpreter event should be ignored");

        let finished = processor.finish().expect("finish should succeed");
        assert!(matches!(
            finished.as_slice(),
            [NormalizedStreamEvent::Done { .. }]
        ));
    }

    #[test]
    fn missing_delta_reports_provider_error() {
        let mut processor = ResponsesNormalizedStreamProcessor::new(options(), parse_response);
        let error = processor
            .handle_payload(json!({"type": "response.output_text.delta"}))
            .expect_err("missing delta should fail");
        assert_eq!(
            error.to_string(),
            provider_error("TestProvider", "missing delta").to_string()
        );
    }
}
602}