Skip to main content

think_scrubber/
lib.rs

/// Bare element names of the reasoning tags, without angle brackets.
///
/// Each entry corresponds, position by position, to the `<name>` form in [`OPEN_TAGS`] and the
/// `</name>` form in [`CLOSE_TAGS`].
pub const OPEN_TAG_NAMES: &[&str] = &[
    "think",
    "thinking",
    "reasoning",
    "thought",
    "REASONING_SCRATCHPAD",
];

// The open/close forms are spelled out as literals (rather than built from OPEN_TAG_NAMES at
// runtime) so the crate needs no dependencies and no lazy initialisation.

/// Opening tags, in the same order as [`OPEN_TAG_NAMES`] and [`CLOSE_TAGS`].
pub const OPEN_TAGS: &[&str] = &[
    "<think>",
    "<thinking>",
    "<reasoning>",
    "<thought>",
    "<REASONING_SCRATCHPAD>",
];

/// Closing tags, in the same order as [`OPEN_TAG_NAMES`] and [`OPEN_TAGS`].
pub const CLOSE_TAGS: &[&str] = &[
    "</think>",
    "</thinking>",
    "</reasoning>",
    "</thought>",
    "</REASONING_SCRATCHPAD>",
];

/// Returns the byte length of the longest entry in `tags` (0 for an empty table).
///
/// Written with a `while` loop because iterators are not usable in `const fn`.
const fn longest_tag_len(tags: &[&str]) -> usize {
    let mut longest = 0;
    let mut i = 0;
    while i < tags.len() {
        if tags[i].len() > longest {
            longest = tags[i].len();
        }
        i += 1;
    }
    longest
}

/// Byte length of the longest close tag (`"</REASONING_SCRATCHPAD>"`, 23 bytes) plus a
/// one-byte safety margin, i.e. 24.
///
/// Derived from [`CLOSE_TAGS`] at compile time so it cannot go stale when a tag is added.
/// Close tags are always one byte longer than their open counterparts, so this bounds both.
pub const MAX_TAG_LEN: usize = longest_tag_len(CLOSE_TAGS) + 1;
27
/// Incremental filter state for removing reasoning-tag blocks from text that arrives in chunks.
#[derive(Debug, Clone)]
pub struct StreamingThinkScrubber {
    /// `true` while the stream is inside an opened block whose close tag has not been seen yet.
    in_block: bool,
    /// Text carried over between calls because it may be the beginning of a tag.
    buf: String,
    /// Whether the most recently emitted text ended with `'\n'`; starts out `true`.
    last_emitted_ended_newline: bool,
}
33
34impl Default for StreamingThinkScrubber {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40impl StreamingThinkScrubber {
41    pub fn new() -> Self {
42        Self {
43            in_block: false,
44            buf: String::new(),
45            last_emitted_ended_newline: true,
46        }
47    }
48
49    pub fn reset(&mut self) {
50        self.in_block = false;
51        self.buf.clear();
52        self.last_emitted_ended_newline = true;
53    }
54
55    pub fn feed(&mut self, text: &str) -> String {
56        if text.is_empty() {
57            return String::new();
58        }
59        let mut buf = format!("{}{}", self.buf, text);
60        self.buf.clear();
61        let mut out: Vec<String> = Vec::new();
62
63        while !buf.is_empty() {
64            if self.in_block {
65                if let Some((close_idx, close_len)) = self.find_first_tag(&buf, CLOSE_TAGS) {
66                    buf = buf[close_idx + close_len..].to_string();
67                    self.in_block = false;
68                } else {
69                    let held = self.max_partial_suffix(&buf, CLOSE_TAGS);
70                    if held > 0 {
71                        self.buf = buf[buf.len() - held..].to_string();
72                    } else {
73                        self.buf.clear();
74                    }
75                    return out.concat();
76                }
77            } else {
78                let pair = self.find_earliest_closed_pair(&buf);
79                let open_opt = self.find_open_at_boundary(&buf, &out);
80
81                match (pair, open_opt) {
82                    (Some((p_start, p_end)), open_val) if open_val.is_none() || p_start <= open_val.unwrap().0 => {
83                        let mut preceding = buf[..p_start].to_string();
84                        if !preceding.is_empty() {
85                            preceding = self.strip_orphan_close_tags(&preceding);
86                            if !preceding.is_empty() {
87                                self.last_emitted_ended_newline = preceding.ends_with('\n');
88                                out.push(preceding);
89                            }
90                        }
91                        buf = buf[p_end..].to_string();
92                        continue;
93                    }
94                    (_, Some((open_idx, open_len))) => {
95                        let mut preceding = buf[..open_idx].to_string();
96                        if !preceding.is_empty() {
97                            preceding = self.strip_orphan_close_tags(&preceding);
98                            if !preceding.is_empty() {
99                                self.last_emitted_ended_newline = preceding.ends_with('\n');
100                                out.push(preceding);
101                            }
102                        }
103                        self.in_block = true;
104                        buf = buf[open_idx + open_len..].to_string();
105                        continue;
106                    }
107                    _ => {
108                        let held_open = self.max_partial_suffix(&buf, OPEN_TAGS);
109                        let held_close = self.max_partial_suffix(&buf, CLOSE_TAGS);
110                        let held = held_open.max(held_close);
111
112                        let emit_text = if held > 0 {
113                            let (emit, hold) = buf.split_at(buf.len() - held);
114                            self.buf = hold.to_string();
115                            emit.to_string()
116                        } else {
117                            self.buf.clear();
118                            buf.clone()
119                        };
120
121                        if !emit_text.is_empty() {
122                            let clean_emit = self.strip_orphan_close_tags(&emit_text);
123                            if !clean_emit.is_empty() {
124                                self.last_emitted_ended_newline = clean_emit.ends_with('\n');
125                                out.push(clean_emit);
126                            }
127                        }
128                        return out.concat();
129                    }
130                }
131            }
132        }
133        out.concat()
134    }
135
136    pub fn flush(&mut self) -> String {
137        if self.in_block {
138            self.buf.clear();
139            self.in_block = false;
140            return String::new();
141        }
142        let tail = self.buf.clone();
143        self.buf.clear();
144        if tail.is_empty() {
145            return String::new();
146        }
147        let clean_tail = self.strip_orphan_close_tags(&tail);
148        if !clean_tail.is_empty() {
149            self.last_emitted_ended_newline = clean_tail.ends_with('\n');
150        }
151        clean_tail
152    }
153
154    fn find_first_tag(&self, buf: &str, tags: &[&str]) -> Option<(usize, usize)> {
155        let buf_lower = buf.to_ascii_lowercase();
156        let mut best_idx = None;
157        let mut best_len = 0;
158        for tag in tags {
159            let tag_lower = tag.to_ascii_lowercase();
160            if let Some(idx) = buf_lower.find(&tag_lower) {
161                if best_idx.is_none() || idx < best_idx.unwrap() {
162                    best_idx = Some(idx);
163                    best_len = tag.len();
164                }
165            }
166        }
167        best_idx.map(|idx| (idx, best_len))
168    }
169
170    fn find_earliest_closed_pair(&self, buf: &str) -> Option<(usize, usize)> {
171        let buf_lower = buf.to_ascii_lowercase();
172        let mut best: Option<(usize, usize)> = None;
173        for (open_tag, close_tag) in OPEN_TAGS.iter().zip(CLOSE_TAGS.iter()) {
174            let open_lower = open_tag.to_ascii_lowercase();
175            let close_lower = close_tag.to_ascii_lowercase();
176
177            if let Some(open_idx) = buf_lower.find(&open_lower) {
178                if let Some(close_idx) = buf_lower[open_idx + open_lower.len()..].find(&close_lower) {
179                    let actual_close_idx = open_idx + open_lower.len() + close_idx;
180                    let end_idx = actual_close_idx + close_lower.len();
181                    if best.is_none() || open_idx < best.unwrap().0 {
182                        best = Some((open_idx, end_idx));
183                    }
184                }
185            }
186        }
187        best
188    }
189
190    fn find_open_at_boundary(&self, buf: &str, already_emitted: &[String]) -> Option<(usize, usize)> {
191        let buf_lower = buf.to_ascii_lowercase();
192        let mut best_idx = None;
193        let mut best_len = 0;
194        for tag in OPEN_TAGS.iter() {
195            let tag_lower = tag.to_ascii_lowercase();
196            let mut search_start = 0;
197            while let Some(idx) = buf_lower[search_start..].find(&tag_lower) {
198                let actual_idx = search_start + idx;
199                if self.is_block_boundary(buf, actual_idx, already_emitted) {
200                    if best_idx.is_none() || actual_idx < best_idx.unwrap() {
201                        best_idx = Some(actual_idx);
202                        best_len = tag.len();
203                    }
204                    break;
205                }
206                search_start = actual_idx + 1;
207            }
208        }
209        best_idx.map(|idx| (idx, best_len))
210    }
211
212    fn is_block_boundary(&self, buf: &str, idx: usize, already_emitted: &[String]) -> bool {
213        if idx == 0 {
214            if !already_emitted.is_empty() {
215                return already_emitted.last().unwrap().ends_with('\n');
216            }
217            return self.last_emitted_ended_newline;
218        }
219        // Since we are indexing the original string using ASCII-derived indices, let's find the newline index
220        let preceding = &buf[..idx];
221        if let Some(last_nl) = preceding.rfind('\n') {
222            preceding[last_nl + 1..].trim().is_empty()
223        } else {
224            let prior_newline = if !already_emitted.is_empty() {
225                already_emitted.last().unwrap().ends_with('\n')
226            } else {
227                self.last_emitted_ended_newline
228            };
229            prior_newline && preceding.trim().is_empty()
230        }
231    }
232
233    fn max_partial_suffix(&self, buf: &str, tags: &[&str]) -> usize {
234        if buf.is_empty() {
235            return 0;
236        }
237        let buf_lower = buf.to_ascii_lowercase();
238        let max_check = buf_lower.len().min(MAX_TAG_LEN - 1);
239        for i in (1..=max_check).rev() {
240            let suffix = &buf_lower[buf_lower.len() - i..];
241            for tag in tags {
242                let tag_lower = tag.to_ascii_lowercase();
243                if tag_lower.len() > i && tag_lower.starts_with(suffix) {
244                    return i;
245                }
246            }
247        }
248        0
249    }
250
251    fn strip_orphan_close_tags(&self, text: &str) -> String {
252        if !text.contains("</") {
253            return text.to_string();
254        }
255        let chars: Vec<char> = text.chars().collect();
256        let chars_lower: Vec<char> = text.to_ascii_lowercase().chars().collect();
257        let mut out = String::new();
258        let mut i = 0;
259        while i < chars.len() {
260            let mut matched = false;
261            if i + 2 <= chars.len() && chars_lower[i] == '<' && chars_lower[i+1] == '/' {
262                for tag in CLOSE_TAGS.iter() {
263                    let tag_chars: Vec<char> = tag.to_ascii_lowercase().chars().collect();
264                    let tag_len = tag_chars.len();
265                    if i + tag_len <= chars.len() && chars_lower[i..i+tag_len] == tag_chars {
266                        let mut j = i + tag_len;
267                        while j < chars.len() && (chars[j] == ' ' || chars[j] == '\t' || chars[j] == '\n' || chars[j] == '\r') {
268                            j += 1;
269                        }
270                        i = j;
271                        matched = true;
272                        break;
273                    }
274                }
275            }
276            if !matched {
277                out.push(chars[i]);
278                i += 1;
279            }
280        }
281        out
282    }
283}
284
285pub fn scrub_string(text: &str) -> String {
286    let mut scrubber = StreamingThinkScrubber::new();
287    let body = scrubber.feed(text);
288    let tail = scrubber.flush();
289    format!("{}{}", body, tail)
290}