Skip to main content

suture_sse/
openai.rs

1use crate::extractor::{DeltaExtractor, Repair};
2use crate::target::{TargetKind, Targets};
3use serde_json::Value;
4
/// OpenAI Chat Completions SSE extractor.
///
/// A stateless unit type: all per-stream state is kept in the `Targets`
/// value handed to each method, so the extractor itself is freely copyable.
#[derive(Debug, Clone, Copy, Default)]
pub struct OpenAi;
7
8impl DeltaExtractor for OpenAi {
9    fn on_event(&self, data: &[u8], targets: &mut Targets) {
10        if self.is_terminator(data) {
11            return;
12        }
13        let v: Value = match serde_json::from_slice(data) {
14            Ok(v) => v,
15            Err(_) => return,
16        };
17        if targets.id.is_none() {
18            if let Some(s) = v.get("id").and_then(Value::as_str) {
19                targets.id = Some(s.to_string());
20            }
21        }
22        if targets.model.is_none() {
23            if let Some(s) = v.get("model").and_then(Value::as_str) {
24                targets.model = Some(s.to_string());
25            }
26        }
27        let Some(choices) = v.get("choices").and_then(Value::as_array) else {
28            return;
29        };
30        for choice in choices {
31            let ci = choice.get("index").and_then(Value::as_u64).unwrap_or(0) as usize;
32            let Some(delta) = choice.get("delta") else {
33                continue;
34            };
35            if let Some(content) = delta.get("content").and_then(Value::as_str) {
36                targets.feed(
37                    TargetKind::Content { choice: ci },
38                    false,
39                    content.as_bytes(),
40                );
41            }
42            if let Some(tcs) = delta.get("tool_calls").and_then(Value::as_array) {
43                for tc in tcs {
44                    let ti = tc.get("index").and_then(Value::as_u64).unwrap_or(0) as usize;
45                    if let Some(args) = tc
46                        .get("function")
47                        .and_then(|f| f.get("arguments"))
48                        .and_then(Value::as_str)
49                    {
50                        targets.feed(
51                            TargetKind::ToolArgs {
52                                choice: ci,
53                                tool: ti,
54                            },
55                            true,
56                            args.as_bytes(),
57                        );
58                    }
59                }
60            }
61        }
62    }
63
64    fn is_terminator(&self, data: &[u8]) -> bool {
65        let start = data
66            .iter()
67            .position(|b| !b.is_ascii_whitespace())
68            .unwrap_or(data.len());
69        let end = data
70            .iter()
71            .rposition(|b| !b.is_ascii_whitespace())
72            .map_or(0, |i| i + 1);
73        &data[start..end] == b"[DONE]"
74    }
75
76    fn synthesize(&self, repairs: &[Repair], targets: &Targets, terminated: bool) -> Vec<u8> {
77        use crate::extractor::json_escape;
78        use crate::target::TargetKind;
79        let mut out = String::new();
80        let id = targets.id.as_deref().unwrap_or("suture-repair");
81        let model = targets.model.as_deref().unwrap_or("");
82        for r in repairs {
83            let esc = json_escape(&r.append);
84            let delta = match r.kind {
85                TargetKind::Content { choice } => {
86                    format!(r#"{{"index":{choice},"delta":{{"content":"{esc}"}}}}"#)
87                }
88                TargetKind::ToolArgs { choice, tool } => format!(
89                    r#"{{"index":{choice},"delta":{{"tool_calls":[{{"index":{tool},"function":{{"arguments":"{esc}"}}}}]}}}}"#
90                ),
91                TargetKind::Block { .. } => continue,
92            };
93            out.push_str(&format!(
94                "data: {{\"id\":\"{id}\",\"object\":\"chat.completion.chunk\",\"model\":\"{model}\",\"choices\":[{delta}]}}\n\n"
95            ));
96        }
97        if !terminated {
98            out.push_str("data: [DONE]\n\n");
99        }
100        out.into_bytes()
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::target::{TargetKind, Targets};

    /// Feeds each payload, in order, through the extractor into a fresh
    /// `Targets` and returns the accumulated state.
    fn run_events(payloads: &[&[u8]]) -> Targets {
        let extractor = OpenAi;
        let mut targets = Targets::new();
        for &payload in payloads {
            extractor.on_event(payload, &mut targets);
        }
        targets
    }

    #[test]
    fn extracts_tool_arguments_fragments() {
        // Two chunks that together carry the truncated arguments `{"x":12`.
        let payloads: [&[u8]; 2] = [
            br#"{"id":"cmpl-1","model":"gpt-4","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"x\":"}}]}}]}"#,
            br#"{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"12"}}]}}]}"#,
        ];
        let targets = run_events(&payloads);

        assert_eq!(targets.id.as_deref(), Some("cmpl-1"));
        assert_eq!(targets.model.as_deref(), Some("gpt-4"));

        let first = targets.iter().next().expect("one target");
        assert_eq!(first.kind, TargetKind::ToolArgs { choice: 0, tool: 0 });

        let repair = first.repair();
        assert!(repair.consistent);
        assert!(repair.safe);
        assert_eq!(repair.append, b"}");
    }

    #[test]
    fn done_is_terminator() {
        let extractor = OpenAi;
        assert!(extractor.is_terminator(b"[DONE]"));
        assert!(!extractor.is_terminator(br#"{"choices":[]}"#));
    }

    #[test]
    fn terminator_ignores_only_surrounding_whitespace() {
        let extractor = OpenAi;
        // Padding around the sentinel is tolerated...
        assert!(extractor.is_terminator(b"  [DONE]\n"));
        // ...but whitespace inside it is not.
        assert!(!extractor.is_terminator(b"[DO NE]"));
    }

    #[test]
    fn plain_text_content_not_repaired() {
        let payloads: [&[u8]; 1] =
            [br#"{"choices":[{"index":0,"delta":{"content":"Hello, I am"}}]}"#];
        let targets = run_events(&payloads);

        let first = targets.iter().next().unwrap();
        assert!(!first.repairable(), "prose content must not be repaired");
    }

    #[test]
    fn json_content_is_repaired() {
        let payloads: [&[u8]; 1] =
            [br#"{"choices":[{"index":0,"delta":{"content":"{\"k\":\"v"}}]}"#];
        let targets = run_events(&payloads);

        let first = targets.iter().next().unwrap();
        assert!(first.repairable());
        assert_eq!(first.repair().append, b"\"}");
    }
}