
// trustformers_tokenizers/streaming.rs

use anyhow::Result as AnyhowResult;
use std::io::{BufRead, BufReader, Read};
use trustformers_core::errors::Result;
use trustformers_core::traits::{TokenizedInput, Tokenizer};

/// Streaming tokenizer for processing large texts efficiently
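///
/// A minimal usage sketch (marked `ignore`, so it is not compiled as a doctest;
/// `my_tokenizer` is a placeholder for any value implementing `Tokenizer`):
///
/// ```ignore
/// let streaming = StreamingTokenizer::new(my_tokenizer)
///     .with_buffer_size(4096)      // split into ~4KB chunks
///     .with_overlap_size(128)      // keep 128 chars of context between chunks
///     .with_max_chunk_length(512); // cap each chunk at 512 tokens
/// let chunks = streaming.process_text("some long document ...")?;
/// ```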
pub struct StreamingTokenizer<T: Tokenizer> {
    tokenizer: T,
    buffer_size: usize,
    overlap_size: usize,
    max_chunk_length: Option<usize>,
}

impl<T: Tokenizer> StreamingTokenizer<T> {
    /// Create a new streaming tokenizer
    pub fn new(tokenizer: T) -> Self {
        Self {
            tokenizer,
            buffer_size: 8192, // 8KB buffer
            overlap_size: 256, // 256 chars overlap between chunks
            max_chunk_length: None,
        }
    }

    /// Set the buffer size for reading from stream
    pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
        self.buffer_size = buffer_size;
        self
    }

    /// Set the overlap size between chunks
    pub fn with_overlap_size(mut self, overlap_size: usize) -> Self {
        self.overlap_size = overlap_size;
        self
    }

    /// Set maximum chunk length for tokenization
    pub fn with_max_chunk_length(mut self, max_length: usize) -> Self {
        self.max_chunk_length = Some(max_length);
        self
    }

    /// Process a stream of text and return tokenized chunks
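    ///
    /// The stream is read line by line; each emitted chunk is prefixed with the
    /// trailing overlap of the previous chunk so context is preserved across
    /// boundaries.
    ///
    /// A minimal sketch (marked `ignore`, not compiled; `my_tokenizer` is a
    /// placeholder for any `Tokenizer` impl, and any `Read` source works):
    ///
    /// ```ignore
    /// use std::io::Cursor;
    ///
    /// let streaming = StreamingTokenizer::new(my_tokenizer);
    /// let chunks = streaming.process_stream(Cursor::new("line 1\nline 2\n"))?;
    /// assert!(!chunks.is_empty());
    /// ```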
    pub fn process_stream<R: Read>(&self, reader: R) -> Result<Vec<TokenizedInput>> {
        let mut buf_reader = BufReader::with_capacity(self.buffer_size, reader);
        let mut chunks = Vec::new();
        let mut buffer = String::new();
        let mut previous_overlap = String::new();

        loop {
            buffer.clear();
            let bytes_read = buf_reader.read_line(&mut buffer).map_err(|e| {
                trustformers_core::errors::TrustformersError::other(format!("I/O error: {}", e))
            })?;

            if bytes_read == 0 {
                break; // End of stream
            }

            // Combine with previous overlap
            let full_text = if previous_overlap.is_empty() {
                buffer.clone()
            } else {
                format!("{}{}", previous_overlap, buffer)
            };

            // Tokenize the chunk
            let tokenized = self.tokenize_chunk(&full_text)?;
            chunks.push(tokenized);

            // Prepare overlap for next chunk.
            // Note: this byte-offset slice assumes the boundary falls on a UTF-8
            // character boundary (always true for ASCII text); otherwise it panics.
            if full_text.len() > self.overlap_size {
                previous_overlap = full_text[full_text.len() - self.overlap_size..].to_string();
            } else {
                previous_overlap.clear();
            }
        }

        Ok(chunks)
    }

    /// Process text from a string in streaming fashion
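    ///
    /// The input is split into windows of at most `buffer_size` bytes, preferring
    /// to break at a space, with consecutive windows overlapping by `overlap_size`
    /// bytes.
    ///
    /// A minimal sketch (marked `ignore`, not compiled; `my_tokenizer` is a
    /// placeholder for any `Tokenizer` impl):
    ///
    /// ```ignore
    /// let streaming = StreamingTokenizer::new(my_tokenizer)
    ///     .with_buffer_size(50)
    ///     .with_overlap_size(10);
    /// let chunks = streaming.process_text("a longer text that spans several chunks")?;
    /// assert!(!chunks.is_empty());
    /// ```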
    pub fn process_text(&self, text: &str) -> Result<Vec<TokenizedInput>> {
        let mut chunks = Vec::new();
        let mut start = 0;
        let chunk_size = self.buffer_size;

        // Handle empty text case
        if text.is_empty() {
            let empty_chunk = self.tokenize_chunk("")?;
            chunks.push(empty_chunk);
            return Ok(chunks);
        }

        while start < text.len() {
            let end = std::cmp::min(start + chunk_size, text.len());
            let mut chunk_end = end;

            // Try to end at a word boundary if possible
            if end < text.len() {
                if let Some(last_space) = text[start..end].rfind(' ') {
                    chunk_end = start + last_space;
                }
            }

            // Ensure we always make progress to avoid infinite loops
            if chunk_end <= start {
                chunk_end = std::cmp::min(start + 1, text.len());
            }

            // Byte-offset slice; assumes `start` and `chunk_end` fall on UTF-8
            // character boundaries (always the case for ASCII input).
            let chunk_text = &text[start..chunk_end];
            let tokenized = self.tokenize_chunk(chunk_text)?;
            chunks.push(tokenized);

            // Move start with overlap, ensuring we always advance
            let next_start = if chunk_end > self.overlap_size {
                chunk_end - self.overlap_size
            } else {
                chunk_end
            };

            // Ensure we always advance at least one position to avoid infinite loop
            start = std::cmp::max(next_start, start + 1);
        }

        Ok(chunks)
    }

    /// Process an iterator of text lines
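    ///
    /// Lines are joined with `'\n'` and flushed as a chunk once the accumulated
    /// text reaches `buffer_size`, keeping the last `overlap_size` characters as
    /// context for the next chunk.
    ///
    /// A minimal sketch (marked `ignore`, not compiled; `my_tokenizer` is a
    /// placeholder for any `Tokenizer` impl):
    ///
    /// ```ignore
    /// let streaming = StreamingTokenizer::new(my_tokenizer).with_buffer_size(20);
    /// let lines = vec!["First line".to_string(), "Second line".to_string()];
    /// let chunks = streaming.process_lines(lines.into_iter())?;
    /// ```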
    pub fn process_lines<I>(&self, lines: I) -> Result<Vec<TokenizedInput>>
    where
        I: Iterator<Item = String>,
    {
        let mut chunks = Vec::new();
        let mut current_chunk = String::new();

        for line in lines {
            // Add line to current chunk
            if !current_chunk.is_empty() {
                current_chunk.push('\n');
            }
            current_chunk.push_str(&line);

            // If chunk is large enough, tokenize it
            if current_chunk.len() >= self.buffer_size {
                let tokenized = self.tokenize_chunk(&current_chunk)?;
                chunks.push(tokenized);

                // Keep overlap
                if current_chunk.len() > self.overlap_size {
                    current_chunk =
                        current_chunk[current_chunk.len() - self.overlap_size..].to_string();
                } else {
                    current_chunk.clear();
                }
            }
        }

        // Process remaining chunk
        if !current_chunk.is_empty() {
            let tokenized = self.tokenize_chunk(&current_chunk)?;
            chunks.push(tokenized);
        }

        Ok(chunks)
    }

    /// Tokenize a single chunk with length constraints
    fn tokenize_chunk(&self, text: &str) -> Result<TokenizedInput> {
        let mut tokenized = self.tokenizer.encode(text)?;

        // Apply max chunk length if specified
        if let Some(max_len) = self.max_chunk_length {
            if tokenized.input_ids.len() > max_len {
                tokenized.input_ids.truncate(max_len);
                tokenized.attention_mask.truncate(max_len);
                if let Some(ref mut token_type_ids) = tokenized.token_type_ids {
                    token_type_ids.truncate(max_len);
                }
            }
        }

        Ok(tokenized)
    }

    /// Get the underlying tokenizer
    pub fn tokenizer(&self) -> &T {
        &self.tokenizer
    }

    /// Get buffer size
    pub fn buffer_size(&self) -> usize {
        self.buffer_size
    }

    /// Get overlap size
    pub fn overlap_size(&self) -> usize {
        self.overlap_size
    }

    /// Get max chunk length
    pub fn max_chunk_length(&self) -> Option<usize> {
        self.max_chunk_length
    }
}

/// Batched streaming tokenizer for processing multiple streams
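///
/// A minimal usage sketch (marked `ignore`, not compiled; `my_tokenizer` is a
/// placeholder for any `Tokenizer` impl):
///
/// ```ignore
/// let batched = BatchedStreamingTokenizer::new(my_tokenizer, 2)
///     .with_streaming_params(50, 10) // buffer_size, overlap_size
///     .with_max_chunk_length(128);
/// let texts = vec!["first text".to_string(), "second text".to_string()];
/// let per_text_chunks = batched.process_text_batch(&texts)?;
/// assert_eq!(per_text_chunks.len(), texts.len());
/// ```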
pub struct BatchedStreamingTokenizer<T: Tokenizer> {
    streaming_tokenizer: StreamingTokenizer<T>,
    batch_size: usize,
}

impl<T: Tokenizer> BatchedStreamingTokenizer<T> {
    /// Create a new batched streaming tokenizer
    pub fn new(tokenizer: T, batch_size: usize) -> Self {
        Self {
            streaming_tokenizer: StreamingTokenizer::new(tokenizer),
            batch_size,
        }
    }

    /// Set streaming parameters
    pub fn with_streaming_params(mut self, buffer_size: usize, overlap_size: usize) -> Self {
        self.streaming_tokenizer = self
            .streaming_tokenizer
            .with_buffer_size(buffer_size)
            .with_overlap_size(overlap_size);
        self
    }

    /// Set max chunk length
    pub fn with_max_chunk_length(mut self, max_length: usize) -> Self {
        self.streaming_tokenizer = self.streaming_tokenizer.with_max_chunk_length(max_length);
        self
    }

    /// Process multiple text streams in batches
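    ///
    /// Texts are processed in groups of `batch_size`; the returned vector has one
    /// entry (that text's tokenized chunks) per input, in the original order, e.g.:
    ///
    /// ```ignore
    /// // Sketch only: with batch_size = 2 and three inputs, the result still has
    /// // three entries, one per text.
    /// let results = batched.process_text_batch(&texts)?;
    /// assert_eq!(results.len(), texts.len());
    /// ```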
    pub fn process_text_batch(&self, texts: &[String]) -> Result<Vec<Vec<TokenizedInput>>> {
        let mut results = Vec::new();

        for batch in texts.chunks(self.batch_size) {
            let mut batch_results = Vec::new();
            for text in batch {
                let tokenized_chunks = self.streaming_tokenizer.process_text(text)?;
                batch_results.push(tokenized_chunks);
            }
            results.extend(batch_results);
        }

        Ok(results)
    }

    /// Get batch size
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// Get the underlying streaming tokenizer
    pub fn streaming_tokenizer(&self) -> &StreamingTokenizer<T> {
        &self.streaming_tokenizer
    }
}

/// Memory-efficient text iterator for large files
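///
/// Each yielded item is a chunk of whole lines, roughly `chunk_size` bytes long
/// (the final chunk may be shorter; the `overlap_size` field is currently unused).
///
/// A minimal usage sketch (marked `ignore`, not compiled; `big.txt` is a
/// hypothetical path):
///
/// ```ignore
/// use std::io::BufReader;
///
/// let reader = BufReader::new(std::fs::File::open("big.txt")?);
/// for chunk in TextFileIterator::new(reader, 64 * 1024, 0) {
///     let chunk = chunk?; // each item is an AnyhowResult<String>
///     // feed `chunk` into a StreamingTokenizer, etc.
/// }
/// ```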
pub struct TextFileIterator<R: BufRead> {
    reader: R,
    buffer: String,
    chunk_size: usize,
    #[allow(dead_code)]
    overlap_size: usize,
    eof: bool,
}

impl<R: BufRead> TextFileIterator<R> {
    /// Create a new text file iterator
    pub fn new(reader: R, chunk_size: usize, overlap_size: usize) -> Self {
        Self {
            reader,
            buffer: String::new(),
            chunk_size,
            overlap_size,
            eof: false,
        }
    }

    /// Read next chunk from the file
    pub fn next_chunk(&mut self) -> AnyhowResult<Option<String>> {
        if self.eof {
            return Ok(None);
        }

        self.buffer.clear();

        // Read chunk_size bytes
        let mut bytes_read = 0;
        let mut temp_buf = String::new();

        while bytes_read < self.chunk_size {
            temp_buf.clear();
            let n = self.reader.read_line(&mut temp_buf)?;
            if n == 0 {
                self.eof = true;
                break;
            }
            self.buffer.push_str(&temp_buf);
            bytes_read += n;
        }

        if self.buffer.is_empty() {
            Ok(None)
        } else {
            Ok(Some(self.buffer.clone()))
        }
    }
}

impl<R: BufRead> Iterator for TextFileIterator<R> {
    type Item = AnyhowResult<String>;

    fn next(&mut self) -> Option<Self::Item> {
        match self.next_chunk() {
            Ok(Some(chunk)) => Some(Ok(chunk)),
            Ok(None) => None,
            Err(e) => Some(Err(e)),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::char::CharTokenizer;
    use std::io::Cursor;

    fn create_test_tokenizer() -> CharTokenizer {
        let mut vocab = std::collections::HashMap::new();
        vocab.insert("a".to_string(), 0);
        vocab.insert("b".to_string(), 1);
        vocab.insert("c".to_string(), 2);
        vocab.insert(" ".to_string(), 3);
        CharTokenizer::new(vocab)
    }

    #[test]
    fn test_streaming_tokenizer_basic() {
        let tokenizer = create_test_tokenizer();
        let streaming = StreamingTokenizer::new(tokenizer);

        let text = "Hello world! This is a test of streaming tokenization.";
        let chunks = streaming.process_text(text).expect("Operation failed in test");

        assert!(!chunks.is_empty());
        // Each chunk should have tokenized content
        for chunk in chunks {
            assert!(!chunk.input_ids.is_empty());
            assert!(!chunk.attention_mask.is_empty());
        }
    }

    #[test]
    fn test_streaming_tokenizer_with_params() {
        let tokenizer = create_test_tokenizer();
        let streaming = StreamingTokenizer::new(tokenizer)
            .with_buffer_size(50)
            .with_overlap_size(10)
            .with_max_chunk_length(20);

        let text = "This is a longer text that should be split into multiple chunks based on the buffer size.";
        let chunks = streaming.process_text(text).expect("Operation failed in test");

        assert!(chunks.len() > 1);

        // Check max chunk length constraint
        for chunk in chunks {
            assert!(chunk.input_ids.len() <= 20);
        }
    }

    #[test]
    fn test_streaming_tokenizer_from_reader() {
        let tokenizer = create_test_tokenizer();
        let streaming = StreamingTokenizer::new(tokenizer);

        let text = "Line 1\nLine 2\nLine 3\n";
        let cursor = Cursor::new(text.as_bytes());
        let chunks = streaming.process_stream(cursor).expect("Operation failed in test");

        assert!(!chunks.is_empty());
        for chunk in chunks {
            assert!(!chunk.input_ids.is_empty());
        }
    }

    #[test]
    fn test_streaming_tokenizer_lines() {
        let tokenizer = create_test_tokenizer();
        let streaming = StreamingTokenizer::new(tokenizer).with_buffer_size(20);

        let lines = vec![
            "First line".to_string(),
            "Second line".to_string(),
            "Third line".to_string(),
        ];

        let chunks = streaming.process_lines(lines.into_iter()).expect("Operation failed in test");
        assert!(!chunks.is_empty());
    }

    #[test]
    fn test_batched_streaming_tokenizer() {
        let tokenizer = create_test_tokenizer();
        let batched = BatchedStreamingTokenizer::new(tokenizer, 2).with_streaming_params(50, 10);

        let texts = vec![
            "First text to tokenize".to_string(),
            "Second text to tokenize".to_string(),
            "Third text to tokenize".to_string(),
        ];

        let results = batched.process_text_batch(&texts).expect("Operation failed in test");
        assert_eq!(results.len(), 3);

        for result in results {
            assert!(!result.is_empty());
            for chunk in result {
                assert!(!chunk.input_ids.is_empty());
            }
        }
    }

    #[test]
    fn test_text_file_iterator() {
        let text = "Line 1\nLine 2\nLine 3\nLine 4\n";
        let cursor = Cursor::new(text.as_bytes());
        let buf_reader = BufReader::new(cursor);

        let iterator = TextFileIterator::new(buf_reader, 10, 2);

        let chunks: std::result::Result<Vec<_>, _> = iterator.collect();
        let chunks = chunks.expect("Operation failed in test");

        assert!(!chunks.is_empty());
        for chunk in chunks {
            assert!(!chunk.is_empty());
        }
    }

    #[test]
    fn test_streaming_empty_text() {
        let tokenizer = create_test_tokenizer();
        let streaming = StreamingTokenizer::new(tokenizer);

        let chunks = streaming.process_text("").expect("Operation failed in test");
        assert_eq!(chunks.len(), 1); // Should have one empty chunk
        assert!(chunks[0].input_ids.is_empty() || chunks[0].input_ids.len() == 1);
        // Might have just padding
    }

    #[test]
    fn test_streaming_configuration() {
        let tokenizer = create_test_tokenizer();
        let streaming = StreamingTokenizer::new(tokenizer)
            .with_buffer_size(1024)
            .with_overlap_size(128)
            .with_max_chunk_length(512);

        assert_eq!(streaming.buffer_size(), 1024);
        assert_eq!(streaming.overlap_size(), 128);
        assert_eq!(streaming.max_chunk_length(), Some(512));
    }
}