1use crate::extractor::{DeltaExtractor, Repair};
2use crate::target::{TargetKind, Targets};
3use serde_json::Value;
4
/// Delta extractor for OpenAI-style streamed chat completions
/// (`chat.completion.chunk` SSE events).
///
/// The type carries no state, so it is freely `Copy`able and `Default`able.
#[derive(Debug, Clone, Copy, Default)]
pub struct OpenAi;
7
8impl DeltaExtractor for OpenAi {
9 fn on_event(&self, data: &[u8], targets: &mut Targets) {
10 if self.is_terminator(data) {
11 return;
12 }
13 let v: Value = match serde_json::from_slice(data) {
14 Ok(v) => v,
15 Err(_) => return,
16 };
17 if targets.id.is_none() {
18 if let Some(s) = v.get("id").and_then(Value::as_str) {
19 targets.id = Some(s.to_string());
20 }
21 }
22 if targets.model.is_none() {
23 if let Some(s) = v.get("model").and_then(Value::as_str) {
24 targets.model = Some(s.to_string());
25 }
26 }
27 let Some(choices) = v.get("choices").and_then(Value::as_array) else {
28 return;
29 };
30 for choice in choices {
31 let ci = choice.get("index").and_then(Value::as_u64).unwrap_or(0) as usize;
32 let Some(delta) = choice.get("delta") else {
33 continue;
34 };
35 if let Some(content) = delta.get("content").and_then(Value::as_str) {
36 targets.feed(
37 TargetKind::Content { choice: ci },
38 false,
39 content.as_bytes(),
40 );
41 }
42 if let Some(tcs) = delta.get("tool_calls").and_then(Value::as_array) {
43 for tc in tcs {
44 let ti = tc.get("index").and_then(Value::as_u64).unwrap_or(0) as usize;
45 if let Some(args) = tc
46 .get("function")
47 .and_then(|f| f.get("arguments"))
48 .and_then(Value::as_str)
49 {
50 targets.feed(
51 TargetKind::ToolArgs {
52 choice: ci,
53 tool: ti,
54 },
55 true,
56 args.as_bytes(),
57 );
58 }
59 }
60 }
61 }
62 }
63
64 fn is_terminator(&self, data: &[u8]) -> bool {
65 let start = data
66 .iter()
67 .position(|b| !b.is_ascii_whitespace())
68 .unwrap_or(data.len());
69 let end = data
70 .iter()
71 .rposition(|b| !b.is_ascii_whitespace())
72 .map_or(0, |i| i + 1);
73 &data[start..end] == b"[DONE]"
74 }
75
76 fn synthesize(&self, repairs: &[Repair], targets: &Targets, terminated: bool) -> Vec<u8> {
77 use crate::extractor::json_escape;
78 use crate::target::TargetKind;
79 let mut out = String::new();
80 let id = targets.id.as_deref().unwrap_or("suture-repair");
81 let model = targets.model.as_deref().unwrap_or("");
82 for r in repairs {
83 let esc = json_escape(&r.append);
84 let delta = match r.kind {
85 TargetKind::Content { choice } => {
86 format!(r#"{{"index":{choice},"delta":{{"content":"{esc}"}}}}"#)
87 }
88 TargetKind::ToolArgs { choice, tool } => format!(
89 r#"{{"index":{choice},"delta":{{"tool_calls":[{{"index":{tool},"function":{{"arguments":"{esc}"}}}}]}}}}"#
90 ),
91 TargetKind::Block { .. } => continue,
92 };
93 out.push_str(&format!(
94 "data: {{\"id\":\"{id}\",\"object\":\"chat.completion.chunk\",\"model\":\"{model}\",\"choices\":[{delta}]}}\n\n"
95 ));
96 }
97 if !terminated {
98 out.push_str("data: [DONE]\n\n");
99 }
100 out.into_bytes()
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::target::{TargetKind, Targets};

    /// Pushes each payload in `events`, in order, through an `OpenAi` extractor
    /// and returns the targets it accumulated.
    fn collect(events: &[&[u8]]) -> Targets {
        let mut targets = Targets::new();
        for payload in events {
            OpenAi.on_event(payload, &mut targets);
        }
        targets
    }

    #[test]
    fn extracts_tool_arguments_fragments() {
        let targets = collect(&[
            br#"{"id":"cmpl-1","model":"gpt-4","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"x\":"}}]}}]}"#,
            br#"{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"12"}}]}}]}"#,
        ]);
        assert_eq!(targets.id.as_deref(), Some("cmpl-1"));
        assert_eq!(targets.model.as_deref(), Some("gpt-4"));

        let first = targets.iter().next().expect("one target");
        assert_eq!(first.kind, TargetKind::ToolArgs { choice: 0, tool: 0 });

        let repair = first.repair();
        assert!(repair.consistent && repair.safe);
        assert_eq!(repair.append, b"}");
    }

    #[test]
    fn done_is_terminator() {
        let extractor = OpenAi;
        assert!(extractor.is_terminator(b"[DONE]"));
        assert!(!extractor.is_terminator(br#"{"choices":[]}"#));
    }

    #[test]
    fn terminator_ignores_only_surrounding_whitespace() {
        let extractor = OpenAi;
        assert!(extractor.is_terminator(b" [DONE]\n"));
        assert!(!extractor.is_terminator(b"[DO NE]"));
    }

    #[test]
    fn plain_text_content_not_repaired() {
        let targets = collect(&[br#"{"choices":[{"index":0,"delta":{"content":"Hello, I am"}}]}"#]);
        let first = targets.iter().next().unwrap();
        assert!(!first.repairable(), "prose content must not be repaired");
    }

    #[test]
    fn json_content_is_repaired() {
        let targets = collect(&[br#"{"choices":[{"index":0,"delta":{"content":"{\"k\":\"v"}}]}"#]);
        let first = targets.iter().next().unwrap();
        assert!(first.repairable());
        assert_eq!(first.repair().append, b"\"}");
    }
}