Skip to main content

synaps_cli/runtime/openai/
wire.rs

1//! SSE wire decoder for OpenAI-compatible streams. Ported from
2//! `openai-runtime::stream`, with `StreamEvent` renamed to `OaiEvent`.
3
4use super::types::{FunctionCall, OaiEvent, ToolCall};
5use serde::Deserialize;
6use std::collections::HashMap;
7
8// ─── Wire types ──────────────────────────────────────────────────────────────
9
10#[derive(Debug, Deserialize)]
11pub struct RawChunk {
12    #[serde(default)]
13    pub choices: Vec<RawChoice>,
14    #[serde(default)]
15    pub usage: Option<RawUsage>,
16}
17
18#[derive(Debug, Deserialize)]
19pub struct RawChoice {
20    #[serde(default)]
21    pub delta: Option<RawDelta>,
22    #[serde(default)]
23    pub finish_reason: Option<String>,
24}
25
26#[derive(Debug, Deserialize, Default)]
27pub struct RawDelta {
28    #[serde(default)]
29    pub role: Option<String>,
30    #[serde(default)]
31    pub content: Option<String>,
32    #[serde(default)]
33    pub tool_calls: Option<Vec<RawToolCallDelta>>,
34}
35
36#[derive(Debug, Deserialize)]
37pub struct RawToolCallDelta {
38    #[serde(default)]
39    pub index: Option<u32>,
40    #[serde(default)]
41    pub id: Option<String>,
42    #[serde(rename = "type", default)]
43    pub kind: Option<String>,
44    #[serde(default)]
45    pub function: Option<RawFunctionDelta>,
46}
47
48#[derive(Debug, Deserialize)]
49pub struct RawFunctionDelta {
50    #[serde(default)]
51    pub name: Option<String>,
52    #[serde(default)]
53    pub arguments: Option<String>,
54}
55
56#[derive(Debug, Deserialize)]
57pub struct RawUsage {
58    #[serde(default)]
59    pub prompt_tokens: Option<u32>,
60    #[serde(default)]
61    pub completion_tokens: Option<u32>,
62    #[serde(default)]
63    pub prompt_tokens_details: Option<RawPromptTokensDetails>,
64}
65
66#[derive(Debug, Deserialize)]
67pub struct RawPromptTokensDetails {
68    #[serde(default)]
69    pub cached_tokens: Option<u32>,
70}
71
72/// Legacy text-only SSE line parser. Kept for simple use-cases; the main
73/// decoder path is `StreamDecoder`.
74pub fn parse_sse_line(line: &str) -> Option<OaiEvent> {
75    let line = line.trim_end_matches('\r');
76    if line.is_empty() || line.starts_with(':') {
77        return None;
78    }
79    let data = line.strip_prefix("data:")?.trim_start();
80    if data == "[DONE]" {
81        return Some(OaiEvent::Done);
82    }
83    let chunk: RawChunk = serde_json::from_str(data).ok()?;
84    let delta = chunk.choices.into_iter().next()?.delta?;
85    if let Some(role) = delta.role {
86        return Some(OaiEvent::RoleStart(role));
87    }
88    if let Some(content) = delta.content {
89        if !content.is_empty() {
90            return Some(OaiEvent::TextDelta(content));
91        }
92    }
93    None
94}
95
96// ─── Accumulator ─────────────────────────────────────────────────────────────
97
/// Buffer for one streamed tool call, keyed by its index in `StreamDecoder`.
#[derive(Default, Debug)]
pub struct ToolCallAccumulator {
    /// Last non-empty `id` seen for this call (empty until one arrives).
    pub id: String,
    /// Last non-empty function name seen for this call.
    pub name: String,
    /// Every non-empty `arguments` fragment, concatenated in arrival order.
    pub arguments: String,
    /// Whether `ToolCallStart` has already been emitted for this call.
    started: bool,
}
105
106#[derive(Debug)]
107pub struct StreamDecoder {
108    pub calls: HashMap<u32, ToolCallAccumulator>,
109    pub truncated: bool,
110    pub completed: bool,
111    role_emitted: bool,
112    done_emitted: bool,
113}
114
115impl Default for StreamDecoder {
116    fn default() -> Self {
117        Self::new()
118    }
119}
120
121impl StreamDecoder {
122    pub fn new() -> Self {
123        Self {
124            calls: HashMap::new(),
125            truncated: false,
126            completed: false,
127            role_emitted: false,
128            done_emitted: false,
129        }
130    }
131
132    pub fn push_line<E: Extend<OaiEvent>>(&mut self, line: &str, sink: &mut E) {
133        let line = line.trim_end_matches('\r');
134        if line.is_empty() || line.starts_with(':') {
135            return;
136        }
137        let payload = match line.strip_prefix("data:").map(str::trim_start) {
138            Some(p) => p,
139            None => return,
140        };
141        if payload == "[DONE]" {
142            self.finish(sink);
143            return;
144        }
145        match serde_json::from_str::<RawChunk>(payload) {
146            Ok(chunk) => self.push_chunk(chunk, sink),
147            Err(e) => sink.extend(Some(OaiEvent::Warning(format!(
148                "sse parse error: {e}; payload={payload:?}"
149            )))),
150        }
151    }
152
153    fn push_chunk<E: Extend<OaiEvent>>(&mut self, chunk: RawChunk, sink: &mut E) {
154        for choice in chunk.choices {
155            let is_finish = choice.finish_reason.is_some();
156            if let Some(delta) = choice.delta {
157                if let Some(role) = delta.role {
158                    if !self.role_emitted {
159                        self.role_emitted = true;
160                        sink.extend(Some(OaiEvent::RoleStart(role)));
161                    }
162                }
163                if let Some(text) = delta.content {
164                    if !text.is_empty() {
165                        sink.extend(Some(OaiEvent::TextDelta(text)));
166                    }
167                }
168                // Process tool_calls — but de-dup finish-chunk re-sends.
169                // Some providers re-send the full tool_calls on the finish frame.
170                // We only skip if the chunk has finish_reason AND the tool_call
171                // has an id (indicating a full re-send, not a final argument delta).
172                if let Some(tcs) = delta.tool_calls {
173                    for tc in tcs {
174                        let is_resend = is_finish && tc.id.as_ref().is_some_and(|id| !id.is_empty());
175                        if !is_resend {
176                            self.apply_tool_call_delta(tc, sink);
177                        }
178                    }
179                }
180            }
181            if let Some(reason) = choice.finish_reason {
182                match reason.as_str() {
183                    "tool_calls" => self.flush_complete(sink),
184                    "length" => {
185                        if !self.calls.is_empty() {
186                            self.truncated = true;
187                            self.flush_complete(sink);
188                        }
189                    }
190                    "stop" | "content_filter" => {}
191                    other => sink.extend(Some(OaiEvent::Warning(format!(
192                        "unknown finish_reason: {other}"
193                    )))),
194                }
195            }
196        }
197        if let Some(u) = chunk.usage {
198            let cached = u.prompt_tokens_details
199                .and_then(|d| d.cached_tokens)
200                .unwrap_or(0);
201            sink.extend(Some(OaiEvent::Usage {
202                prompt_tokens: u.prompt_tokens.unwrap_or(0),
203                completion_tokens: u.completion_tokens.unwrap_or(0),
204                cached_tokens: cached,
205            }));
206        }
207    }
208
209    fn apply_tool_call_delta<E: Extend<OaiEvent>>(
210        &mut self,
211        tc: RawToolCallDelta,
212        sink: &mut E,
213    ) {
214        let idx = tc.index.unwrap_or(0);
215        let acc = self.calls.entry(idx).or_default();
216
217        if let Some(id) = tc.id {
218            if !id.is_empty() {
219                acc.id = id;
220            }
221        }
222        if let Some(f) = tc.function {
223            if let Some(n) = f.name {
224                if !n.is_empty() {
225                    acc.name = n;
226                }
227            }
228            if !acc.started && !acc.id.is_empty() && !acc.name.is_empty() {
229                acc.started = true;
230                sink.extend(Some(OaiEvent::ToolCallStart {
231                    index: idx,
232                    id: acc.id.clone(),
233                    name: acc.name.clone(),
234                }));
235            }
236            if let Some(args) = f.arguments {
237                if !args.is_empty() {
238                    acc.arguments.push_str(&args);
239                    sink.extend(Some(OaiEvent::ToolCallArgumentsDelta {
240                        index: idx,
241                        id: acc.id.clone(),
242                        delta: args,
243                    }));
244                }
245            }
246        } else if !acc.started && !acc.id.is_empty() && !acc.name.is_empty() {
247            acc.started = true;
248            sink.extend(Some(OaiEvent::ToolCallStart {
249                index: idx,
250                id: acc.id.clone(),
251                name: acc.name.clone(),
252            }));
253        }
254    }
255
256    pub fn finish<E: Extend<OaiEvent>>(&mut self, sink: &mut E) {
257        self.flush_complete(sink);
258        if !self.done_emitted {
259            self.done_emitted = true;
260            sink.extend(Some(OaiEvent::Done));
261        }
262    }
263
264    fn flush_complete<E: Extend<OaiEvent>>(&mut self, sink: &mut E) {
265        if self.completed || self.calls.is_empty() {
266            return;
267        }
268        self.completed = true;
269        let mut entries: Vec<(u32, ToolCallAccumulator)> = self.calls.drain().collect();
270        entries.sort_by_key(|(k, _)| *k);
271        let calls: Vec<ToolCall> = entries
272            .into_iter()
273            .map(|(_, acc)| ToolCall {
274                id: acc.id,
275                kind: "function".to_string(),
276                function: FunctionCall {
277                    name: acc.name,
278                    arguments: acc.arguments,
279                },
280            })
281            .collect();
282        sink.extend(Some(OaiEvent::ToolCallsComplete {
283            calls,
284            truncated: self.truncated,
285        }));
286    }
287}