// trustformers_tokenizers/streaming.rs

use anyhow::Result as AnyhowResult;
use std::io::{BufRead, BufReader, Read};
use trustformers_core::errors::Result;
use trustformers_core::traits::{TokenizedInput, Tokenizer};

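/// Tokenizes text that is too large to process in one call by splitting it
/// into chunks of roughly `buffer_size` bytes, carrying `overlap_size` bytes
/// of context between consecutive chunks.
///
/// A minimal usage sketch, mirroring the tests below (marked `ignore` since
/// the exact crate module paths are assumed):
///
/// ```ignore
/// use std::collections::HashMap;
/// use trustformers_tokenizers::char::CharTokenizer;
/// use trustformers_tokenizers::streaming::StreamingTokenizer;
///
/// let mut vocab = HashMap::new();
/// vocab.insert("a".to_string(), 0);
/// let tokenizer = CharTokenizer::new(vocab);
/// let streaming = StreamingTokenizer::new(tokenizer)
///     .with_buffer_size(4096)
///     .with_overlap_size(128);
/// let chunks = streaming.process_text("some long input text").unwrap();
/// ```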
pub struct StreamingTokenizer<T: Tokenizer> {
    tokenizer: T,
    buffer_size: usize,
    overlap_size: usize,
    max_chunk_length: Option<usize>,
}

impl<T: Tokenizer> StreamingTokenizer<T> {
    pub fn new(tokenizer: T) -> Self {
        Self {
            tokenizer,
            buffer_size: 8192,
            overlap_size: 256,
            max_chunk_length: None,
        }
    }

    pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
        self.buffer_size = buffer_size;
        self
    }

    pub fn with_overlap_size(mut self, overlap_size: usize) -> Self {
        self.overlap_size = overlap_size;
        self
    }

    pub fn with_max_chunk_length(mut self, max_length: usize) -> Self {
        self.max_chunk_length = Some(max_length);
        self
    }

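    /// Reads the stream line by line, prepending the trailing `overlap_size`
    /// bytes of the previous chunk so that tokens spanning line boundaries
    /// keep some surrounding context.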
    pub fn process_stream<R: Read>(&self, reader: R) -> Result<Vec<TokenizedInput>> {
        let mut buf_reader = BufReader::with_capacity(self.buffer_size, reader);
        let mut chunks = Vec::new();
        let mut buffer = String::new();
        let mut previous_overlap = String::new();

        loop {
            buffer.clear();
            let bytes_read = buf_reader.read_line(&mut buffer).map_err(|e| {
                trustformers_core::errors::TrustformersError::other(format!("I/O error: {}", e))
            })?;

            if bytes_read == 0 {
                break;
            }

            let full_text = if previous_overlap.is_empty() {
                buffer.clone()
            } else {
                format!("{}{}", previous_overlap, buffer)
            };

            let tokenized = self.tokenize_chunk(&full_text)?;
            chunks.push(tokenized);

            // Carry the tail of this chunk into the next one. The slice
            // offset is in bytes, so the overlap boundary must fall on a
            // UTF-8 character boundary to avoid a panic on multi-byte text.
            if full_text.len() > self.overlap_size {
                previous_overlap = full_text[full_text.len() - self.overlap_size..].to_string();
            } else {
                previous_overlap.clear();
            }
        }

        Ok(chunks)
    }

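    /// Splits `text` into windows of at most `buffer_size` bytes, preferring
    /// to break at the last space inside each window, and re-includes
    /// `overlap_size` bytes of the previous chunk at the start of the next.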
    pub fn process_text(&self, text: &str) -> Result<Vec<TokenizedInput>> {
        let mut chunks = Vec::new();
        let mut start = 0;
        let chunk_size = self.buffer_size;

        if text.is_empty() {
            let empty_chunk = self.tokenize_chunk("")?;
            chunks.push(empty_chunk);
            return Ok(chunks);
        }

        while start < text.len() {
            let end = std::cmp::min(start + chunk_size, text.len());
            let mut chunk_end = end;

            // Prefer to break at the last space inside the window so that
            // words are not split across chunks.
            if end < text.len() {
                if let Some(last_space) = text[start..end].rfind(' ') {
                    chunk_end = start + last_space;
                }
            }

            // If no usable break point was found, advance by at least one
            // byte so the loop always makes progress.
            if chunk_end <= start {
                chunk_end = std::cmp::min(start + 1, text.len());
            }

            let chunk_text = &text[start..chunk_end];
            let tokenized = self.tokenize_chunk(chunk_text)?;
            chunks.push(tokenized);

            // Step back by `overlap_size` bytes so consecutive chunks share
            // context, but never stall: the next start is at least start + 1.
            let next_start = if chunk_end > self.overlap_size {
                chunk_end - self.overlap_size
            } else {
                chunk_end
            };

            start = std::cmp::max(next_start, start + 1);
        }

        Ok(chunks)
    }

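    /// Accumulates lines into a buffer, tokenizing and flushing it whenever
    /// it reaches `buffer_size` bytes; the trailing `overlap_size` bytes are
    /// kept as the seed of the next buffer.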
    pub fn process_lines<I>(&self, lines: I) -> Result<Vec<TokenizedInput>>
    where
        I: Iterator<Item = String>,
    {
        let mut chunks = Vec::new();
        let mut current_chunk = String::new();

        for line in lines {
            if !current_chunk.is_empty() {
                current_chunk.push('\n');
            }
            current_chunk.push_str(&line);

            if current_chunk.len() >= self.buffer_size {
                let tokenized = self.tokenize_chunk(&current_chunk)?;
                chunks.push(tokenized);

                if current_chunk.len() > self.overlap_size {
                    current_chunk =
                        current_chunk[current_chunk.len() - self.overlap_size..].to_string();
                } else {
                    current_chunk.clear();
                }
            }
        }

        // Flush whatever remains after the iterator is exhausted.
        if !current_chunk.is_empty() {
            let tokenized = self.tokenize_chunk(&current_chunk)?;
            chunks.push(tokenized);
        }

        Ok(chunks)
    }

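    /// Encodes one chunk and, if `max_chunk_length` is set, truncates the
    /// input ids, attention mask, and token type ids to that length.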
    fn tokenize_chunk(&self, text: &str) -> Result<TokenizedInput> {
        let mut tokenized = self.tokenizer.encode(text)?;

        if let Some(max_len) = self.max_chunk_length {
            if tokenized.input_ids.len() > max_len {
                tokenized.input_ids.truncate(max_len);
                tokenized.attention_mask.truncate(max_len);
                if let Some(ref mut token_type_ids) = tokenized.token_type_ids {
                    token_type_ids.truncate(max_len);
                }
            }
        }

        Ok(tokenized)
    }

    pub fn tokenizer(&self) -> &T {
        &self.tokenizer
    }

    pub fn buffer_size(&self) -> usize {
        self.buffer_size
    }

    pub fn overlap_size(&self) -> usize {
        self.overlap_size
    }

    pub fn max_chunk_length(&self) -> Option<usize> {
        self.max_chunk_length
    }
}

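/// Wraps a [`StreamingTokenizer`] and processes collections of texts in
/// batches of `batch_size`.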
pub struct BatchedStreamingTokenizer<T: Tokenizer> {
    streaming_tokenizer: StreamingTokenizer<T>,
    batch_size: usize,
}

impl<T: Tokenizer> BatchedStreamingTokenizer<T> {
    pub fn new(tokenizer: T, batch_size: usize) -> Self {
        Self {
            streaming_tokenizer: StreamingTokenizer::new(tokenizer),
            batch_size,
        }
    }

    pub fn with_streaming_params(mut self, buffer_size: usize, overlap_size: usize) -> Self {
        self.streaming_tokenizer = self
            .streaming_tokenizer
            .with_buffer_size(buffer_size)
            .with_overlap_size(overlap_size);
        self
    }

    pub fn with_max_chunk_length(mut self, max_length: usize) -> Self {
        self.streaming_tokenizer = self.streaming_tokenizer.with_max_chunk_length(max_length);
        self
    }

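    /// Runs each text through [`StreamingTokenizer::process_text`], iterating
    /// over the inputs in batches of `batch_size`. Batch results are
    /// flattened, so the output contains one `Vec<TokenizedInput>` per input
    /// text, in order.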
    pub fn process_text_batch(&self, texts: &[String]) -> Result<Vec<Vec<TokenizedInput>>> {
        let mut results = Vec::new();

        for batch in texts.chunks(self.batch_size) {
            let mut batch_results = Vec::new();
            for text in batch {
                let tokenized_chunks = self.streaming_tokenizer.process_text(text)?;
                batch_results.push(tokenized_chunks);
            }
            results.extend(batch_results);
        }

        Ok(results)
    }

    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    pub fn streaming_tokenizer(&self) -> &StreamingTokenizer<T> {
        &self.streaming_tokenizer
    }
}

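/// Iterates over a buffered reader, yielding chunks of at least `chunk_size`
/// bytes built from whole lines. `overlap_size` is stored for future use but
/// is not currently applied between chunks.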
pub struct TextFileIterator<R: BufRead> {
    reader: R,
    buffer: String,
    chunk_size: usize,
    #[allow(dead_code)]
    overlap_size: usize,
    eof: bool,
}

impl<R: BufRead> TextFileIterator<R> {
    pub fn new(reader: R, chunk_size: usize, overlap_size: usize) -> Self {
        Self {
            reader,
            buffer: String::new(),
            chunk_size,
            overlap_size,
            eof: false,
        }
    }

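    /// Reads whole lines until the accumulated buffer reaches `chunk_size`
    /// bytes or the reader is exhausted. Returns `Ok(None)` once the end of
    /// input has been reached and nothing remains in the buffer.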
    pub fn next_chunk(&mut self) -> AnyhowResult<Option<String>> {
        if self.eof {
            return Ok(None);
        }

        self.buffer.clear();

        let mut bytes_read = 0;
        let mut temp_buf = String::new();

        while bytes_read < self.chunk_size {
            temp_buf.clear();
            let n = self.reader.read_line(&mut temp_buf)?;
            if n == 0 {
                self.eof = true;
                break;
            }
            self.buffer.push_str(&temp_buf);
            bytes_read += n;
        }

        if self.buffer.is_empty() {
            Ok(None)
        } else {
            Ok(Some(self.buffer.clone()))
        }
    }
}

impl<R: BufRead> Iterator for TextFileIterator<R> {
    type Item = AnyhowResult<String>;

    fn next(&mut self) -> Option<Self::Item> {
        match self.next_chunk() {
            Ok(Some(chunk)) => Some(Ok(chunk)),
            Ok(None) => None,
            Err(e) => Some(Err(e)),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::char::CharTokenizer;
    use std::io::Cursor;

    fn create_test_tokenizer() -> CharTokenizer {
        let mut vocab = std::collections::HashMap::new();
        vocab.insert("a".to_string(), 0);
        vocab.insert("b".to_string(), 1);
        vocab.insert("c".to_string(), 2);
        vocab.insert(" ".to_string(), 3);
        CharTokenizer::new(vocab)
    }

    #[test]
    fn test_streaming_tokenizer_basic() {
        let tokenizer = create_test_tokenizer();
        let streaming = StreamingTokenizer::new(tokenizer);

        let text = "Hello world! This is a test of streaming tokenization.";
        let chunks = streaming
            .process_text(text)
            .expect("Operation failed in test");

        assert!(!chunks.is_empty());
        for chunk in chunks {
            assert!(!chunk.input_ids.is_empty());
            assert!(!chunk.attention_mask.is_empty());
        }
    }

    #[test]
    fn test_streaming_tokenizer_with_params() {
        let tokenizer = create_test_tokenizer();
        let streaming = StreamingTokenizer::new(tokenizer)
            .with_buffer_size(50)
            .with_overlap_size(10)
            .with_max_chunk_length(20);

        let text = "This is a longer text that should be split into multiple chunks based on the buffer size.";
        let chunks = streaming
            .process_text(text)
            .expect("Operation failed in test");

        assert!(chunks.len() > 1);

        for chunk in chunks {
            assert!(chunk.input_ids.len() <= 20);
        }
    }

    #[test]
    fn test_streaming_tokenizer_from_reader() {
        let tokenizer = create_test_tokenizer();
        let streaming = StreamingTokenizer::new(tokenizer);

        let text = "Line 1\nLine 2\nLine 3\n";
        let cursor = Cursor::new(text.as_bytes());
        let chunks = streaming
            .process_stream(cursor)
            .expect("Operation failed in test");

        assert!(!chunks.is_empty());
        for chunk in chunks {
            assert!(!chunk.input_ids.is_empty());
        }
    }

    #[test]
    fn test_streaming_tokenizer_lines() {
        let tokenizer = create_test_tokenizer();
        let streaming = StreamingTokenizer::new(tokenizer).with_buffer_size(20);

        let lines = vec![
            "First line".to_string(),
            "Second line".to_string(),
            "Third line".to_string(),
        ];

        let chunks = streaming
            .process_lines(lines.into_iter())
            .expect("Operation failed in test");
        assert!(!chunks.is_empty());
    }

    #[test]
    fn test_batched_streaming_tokenizer() {
        let tokenizer = create_test_tokenizer();
        let batched = BatchedStreamingTokenizer::new(tokenizer, 2).with_streaming_params(50, 10);

        let texts = vec![
            "First text to tokenize".to_string(),
            "Second text to tokenize".to_string(),
            "Third text to tokenize".to_string(),
        ];

        let results = batched
            .process_text_batch(&texts)
            .expect("Operation failed in test");
        assert_eq!(results.len(), 3);

        for result in results {
            assert!(!result.is_empty());
            for chunk in result {
                assert!(!chunk.input_ids.is_empty());
            }
        }
    }

    #[test]
    fn test_text_file_iterator() {
        let text = "Line 1\nLine 2\nLine 3\nLine 4\n";
        let cursor = Cursor::new(text.as_bytes());
        let buf_reader = BufReader::new(cursor);

        let iterator = TextFileIterator::new(buf_reader, 10, 2);

        let chunks: std::result::Result<Vec<_>, _> = iterator.collect();
        let chunks = chunks.expect("Operation failed in test");

        assert!(!chunks.is_empty());
        for chunk in chunks {
            assert!(!chunk.is_empty());
        }
    }

    #[test]
    fn test_streaming_empty_text() {
        let tokenizer = create_test_tokenizer();
        let streaming = StreamingTokenizer::new(tokenizer);

        let chunks = streaming.process_text("").expect("Operation failed in test");
        assert_eq!(chunks.len(), 1);
        assert!(chunks[0].input_ids.is_empty() || chunks[0].input_ids.len() == 1);
    }

    #[test]
    fn test_streaming_configuration() {
        let tokenizer = create_test_tokenizer();
        let streaming = StreamingTokenizer::new(tokenizer)
            .with_buffer_size(1024)
            .with_overlap_size(128)
            .with_max_chunk_length(512);

        assert_eq!(streaming.buffer_size(), 1024);
        assert_eq!(streaming.overlap_size(), 128);
        assert_eq!(streaming.max_chunk_length(), Some(512));
    }
}